###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Schema validation and migration utilities.

This code interacts directly with the database, outside of the ORM,
taking a `Profile` as input for the connection configuration.

.. important:: This code should only be accessed via the storage backend class, not directly!
"""
from __future__ import annotations
import contextlib
import pathlib
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional
from alembic.command import downgrade, upgrade
from alembic.config import Config
from alembic.runtime.environment import EnvironmentContext
from alembic.runtime.migration import MigrationContext, MigrationInfo
from alembic.script import ScriptDirectory
from sqlalchemy import MetaData, String, column, desc, insert, inspect, select, table
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session
from aiida.common import exceptions
from aiida.manage.configuration.profile import Profile
from aiida.storage.log import MIGRATE_LOGGER
from aiida.storage.psql_dos.models.settings import DbSetting
from aiida.storage.psql_dos.utils import create_sqlalchemy_engine
if TYPE_CHECKING:
    # Only needed for type annotations; importing lazily avoids a hard runtime dependency at import time.
    from disk_objectstore import Container

# Message shown when the database still carries the legacy Django schema (no alembic version table).
TEMPLATE_LEGACY_DJANGO_SCHEMA = """
Database schema is using the legacy Django schema.
To migrate the database schema version to the current one, run the following command:
verdi -p {profile_name} storage migrate
"""

# Message shown when the database schema version differs from the head version supported by the code.
TEMPLATE_INVALID_SCHEMA_VERSION = """
Database schema version `{schema_version_database}` is incompatible with the required schema version `{schema_version_code}`.
To migrate the database schema version to the current one, run the following command:
verdi -p {profile_name} storage migrate
"""  # noqa: E501

# Directory, relative to this module, that holds the alembic migration scripts.
ALEMBIC_REL_PATH = 'migrations'

# Key in the settings table under which the UUID of the associated disk-objectstore container is stored.
REPOSITORY_UUID_KEY = 'repository|uuid'
class PsqlDosMigrator:
    """Class for validating and migrating `psql_dos` storage instances.

    .. important:: This class should only be accessed via the storage backend class (apart from for test purposes)
    """

    alembic_version_tbl_name = 'alembic_version'
    django_version_table = table(
        'django_migrations', column('id'), column('app', String(255)), column('name', String(255)), column('applied')
    )

    def __init__(self, profile: Profile) -> None:
        """Construct the migrator for the given profile.

        The SQLAlchemy engine is created from the profile's storage configuration; the connection itself is only
        opened on first access of :attr:`connection`.

        :param profile: The profile whose storage should be validated or migrated.
        """
        self.profile = profile
        self._engine = create_sqlalchemy_engine(self.profile.storage_config)
        self._connection = None

    def close(self) -> None:
        """Close the connection if it was opened and dispose of the engine."""
        if self._connection:
            self._connection.close()
            self._connection = None
        if self._engine:
            self._engine.dispose()
            self._engine = None

    @property
    def connection(self):
        """Return the connection to the database.

        Will automatically open a connection from the engine if one was not already opened in a previous call.

        :return: Open connection to the database.
        :raises: :class:`aiida.common.exceptions.UnreachableStorage` if connecting to the database fails.
        """
        if self._connection is None:
            try:
                self._connection = self._engine.connect()
            except OperationalError as exception:
                raise exceptions.UnreachableStorage(f'Could not connect to database: {exception}') from exception
        return self._connection

    @classmethod
    def get_schema_versions(cls) -> Dict[str, str]:
        """Return all available schema versions (oldest to latest).

        :return: schema version -> description
        """
        return {entry.revision: entry.doc for entry in reversed(list(cls._alembic_script().walk_revisions()))}

    @classmethod
    def get_schema_version_head(cls) -> str:
        """Return the head schema version for this storage, i.e. the latest schema this storage can be migrated to."""
        return cls._alembic_script().revision_map.get_current_head('main')

    def get_schema_version_profile(self, check_legacy=False) -> Optional[str]:
        """Return the schema version of the backend instance for this profile.

        Unless ``check_legacy`` is ``True``, the version will be ``None`` if the database is empty or is a legacy
        django database.

        :param check_legacy: If ``True`` and no alembic revision is found, fall back to the name of the most recent
            ``db`` migration recorded in the legacy ``django_migrations`` table.
        :return: The current schema version, or ``None`` if it cannot be determined.
        """
        with self._migration_context() as context:
            version = context.get_current_revision()
        if version is None and check_legacy:
            stmt = select(self.django_version_table.c.name).where(self.django_version_table.c.app == 'db')
            stmt = stmt.order_by(desc(self.django_version_table.c.id)).limit(1)
            try:
                return self.connection.execute(stmt).scalar()
            except (OperationalError, ProgrammingError):
                # The legacy table is missing or unreadable: roll back so the connection stays usable after the
                # failed statement.
                self.connection.rollback()
        return version

    def validate_storage(self) -> None:
        """Validate that the storage for this profile is usable with the current code.

        The following is checked:

        1. That the database schema is at the head version, i.e. is compatible with the code API.
        2. That the repository ID is equal to the UUID set in the database.

        :raises: :class:`aiida.common.exceptions.UnreachableStorage` if the storage cannot be connected to
        :raises: :class:`aiida.common.exceptions.IncompatibleStorageSchema`
            if the storage is not compatible with the code API.
        :raises: :class:`aiida.common.exceptions.CorruptStorage`
            if the repository ID is not equal to the UUID set in the database.
        """
        inspector = inspect(self.connection)

        # check there is an alembic_version table from which to get the schema version
        if not inspector.has_table(self.alembic_version_tbl_name):
            # if not present, it might be that this is a legacy django database
            if inspector.has_table(self.django_version_table.name):
                raise exceptions.IncompatibleStorageSchema(
                    TEMPLATE_LEGACY_DJANGO_SCHEMA.format(profile_name=self.profile.name)
                )
            raise exceptions.IncompatibleStorageSchema('The database has no known version.')

        # now we can check that the alembic version is the latest
        schema_version_code = self.get_schema_version_head()
        schema_version_database = self.get_schema_version_profile(check_legacy=False)
        if schema_version_database != schema_version_code:
            raise exceptions.IncompatibleStorageSchema(
                TEMPLATE_INVALID_SCHEMA_VERSION.format(
                    schema_version_database=schema_version_database,
                    schema_version_code=schema_version_code,
                    profile_name=self.profile.name,
                )
            )

        # finally, we check that the ID set within the disk-objectstore is equal to the one saved in the database,
        # i.e. this container is indeed the one associated with the db
        repository_uuid = self.get_repository_uuid()
        stmt = select(DbSetting.val).where(DbSetting.key == REPOSITORY_UUID_KEY)
        database_repository_uuid = self.connection.execute(stmt).scalar_one_or_none()
        if database_repository_uuid is None:
            raise exceptions.CorruptStorage('The database has no repository UUID set.')
        if database_repository_uuid != repository_uuid:
            raise exceptions.CorruptStorage(
                f'The database has a repository UUID configured to {database_repository_uuid} '
                f"but the disk-objectstore's is {repository_uuid}."
            )

    def get_container(self) -> 'Container':
        """Return the disk-object store container.

        :returns: The disk-object store container configured for the repository path of the current profile.
        """
        from disk_objectstore import Container

        from .backend import get_filepath_container

        return Container(get_filepath_container(self.profile))

    def get_repository_uuid(self) -> str:
        """Return the UUID of the repository.

        :returns: The repository UUID.
        :raises: :class:`~aiida.common.exceptions.UnreachableStorage` if the UUID cannot be retrieved, which probably
            means that the repository is not initialised.
        """
        container = None
        try:
            container = self.get_container()
            return container.container_id
        except Exception as exception:
            # Do not call ``get_container`` again here: if constructing the container is what failed, a second call
            # would raise a new exception from inside this handler and mask the ``UnreachableStorage`` error.
            target = container if container is not None else 'container'
            raise exceptions.UnreachableStorage(
                f'Could not access disk-objectstore {target}: {exception}'
            ) from exception

    def initialise(self, reset: bool = False) -> bool:
        """Initialise the storage backend.

        This is typically used once when a new storage backend is created. If this method returns without exceptions
        the storage backend is ready for use. If the backend already seems initialised, this method is a no-op.

        :param reset: If ``True``, destroy the backend if it already exists including all of its data before
            recreating and initialising it. This is useful for example for test profiles that need to be reset before
            or after tests having run.
        :returns: ``True`` if the storage was initialised by the function call, ``False`` if it was already initialised.
        """
        if reset:
            self.reset_repository()
            self.reset_database()

        initialised: bool = False

        if not self.is_initialised:
            self.initialise_repository()
            self.initialise_database()
            initialised = True

        # Call migrate in the case the storage was already initialised but not yet at the latest schema version. If it
        # was, then the following is a no-op anyway.
        self.migrate()

        return initialised

    @property
    def is_initialised(self) -> bool:
        """Return whether the storage is initialised.

        This is the case if both the database and the repository are initialised.

        :returns: ``True`` if the storage is initialised, ``False`` otherwise.
        """
        return self.is_repository_initialised and self.is_database_initialised

    @property
    def is_repository_initialised(self) -> bool:
        """Return whether the repository is initialised.

        :returns: ``True`` if the repository is initialised, ``False`` otherwise.
        """
        return self.get_container().is_initialised

    @property
    def is_database_initialised(self) -> bool:
        """Return whether the database is initialised.

        This is the case if it contains the table that holds the schema version for alembic or Django.

        :returns: ``True`` if the database is initialised, ``False`` otherwise.
        """
        inspector = inspect(self.connection)
        return inspector.has_table(self.alembic_version_tbl_name) or inspector.has_table(
            self.django_version_table.name
        )

    def reset_repository(self) -> None:
        """Reset the repository by deleting all of its contents.

        This will also destroy the configuration and so in order to use it again, it will have to be reinitialised.
        A repository folder that does not exist is silently ignored.
        """
        import shutil

        with contextlib.suppress(FileNotFoundError):
            shutil.rmtree(self.get_container().get_folder())

    def reset_database(self) -> None:
        """Reset the database by deleting all content from all tables.

        This will also destroy the settings table and so in order to use it again, it will have to be reinitialised.
        The alembic version table is preserved, so the recorded schema version survives the reset.
        """
        self.delete_all_tables(exclude_tables=[self.alembic_version_tbl_name])

    def initialise_repository(self) -> None:
        """Initialise the repository."""
        from aiida.storage.psql_dos.backend import CONTAINER_DEFAULTS

        container = self.get_container()
        container.init_container(clear=True, **CONTAINER_DEFAULTS)

    def initialise_database(self) -> None:
        """Initialise the database.

        This assumes that the database has no schema whatsoever and so the initial schema is created directly from the
        models at the current head version without migrating through all of them one by one.
        """
        from aiida.storage.psql_dos.models.base import get_orm_metadata

        # setup the database
        # see: https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch
        MIGRATE_LOGGER.report('initialising empty storage schema')
        get_orm_metadata().create_all(self._engine)

        repository_uuid = self.get_repository_uuid()

        # Create a "sync" between the database and repository, by saving its UUID in the settings table
        # this allows us to validate inconsistencies between the two
        self.connection.execute(
            insert(DbSetting).values(key=REPOSITORY_UUID_KEY, val=repository_uuid, description='Repository UUID')
        )

        # finally, generate the version table, "stamping" it with the most recent revision
        with self._migration_context() as context:
            context.stamp(context.script, 'main@head')
            self.connection.commit()

    def delete_all_tables(self, *, exclude_tables: list[str] | None = None) -> None:
        """Delete all tables of the current database schema.

        The tables are determined dynamically through reflection of the current schema version. Any other tables in the
        database that are not part of the schema should remain unaffected. Nothing is deleted if the alembic version
        table does not exist.

        :param exclude_tables: Optional list of table names that should not be deleted.
        """
        exclude_tables = exclude_tables or []

        if inspect(self.connection).has_table(self.alembic_version_tbl_name):
            metadata = MetaData()
            metadata.reflect(bind=self.connection)

            # ``sorted_tables`` returns the tables in the order in which they can be created, i.e. each table comes
            # after the tables its foreign keys reference. Iterating in reverse therefore deletes the rows of
            # dependent tables before the rows they reference, so no foreign-key constraint is violated.
            for schema_table in reversed(metadata.sorted_tables):
                if schema_table.name in exclude_tables:
                    continue
                self.connection.execute(schema_table.delete())

        self.connection.commit()

    def migrate(self) -> None:
        """Migrate the storage for this profile to the head version.

        :raises: :class:`~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed.
        :raises: :class:`~aiida.common.exceptions.StorageMigrationError` if the storage is not initialised or its
            current schema version cannot be determined.
        """
        # The database can be in one of a few states:
        # 1. Legacy django database -> we transfer the version to alembic, migrate to the head of the django branch,
        #    reset the revision as one on the main branch, and then migrate to the head of the main branch
        # 2. Legacy sqlalchemy database -> we migrate to the head of the sqlalchemy branch,
        #    reset the revision as one on the main branch, and then migrate to the head of the main branch
        # 3. Already on the main branch -> we migrate to the head of the main branch

        if not inspect(self.connection).has_table(self.alembic_version_tbl_name):
            if not inspect(self.connection).has_table(self.django_version_table.name):
                raise exceptions.StorageMigrationError('storage is uninitialised, cannot migrate.')
            # the database is a legacy django one,
            # so we need to copy the version from the 'django_migrations' table to the 'alembic_version' one
            legacy_version = self.get_schema_version_profile(check_legacy=True)
            if legacy_version is None:
                raise exceptions.StorageMigrationError(
                    'No schema version could be read from the database. '
                    "Check that either the 'alembic_version' or 'django_migrations' tables "
                    'are present and accessible, using e.g. `verdi devel run-sql "SELECT * FROM alembic_version"`'
                )
            # the version should be of the format '00XX_description'
            version = f'django_{legacy_version[:4]}'
            with self._migration_context() as context:
                context.stamp(context.script, version)
                self.connection.commit()
            # now we can continue with the migration as normal
        else:
            version = self.get_schema_version_profile()

        # find what branch the current version is on
        revision = self._alembic_script().revision_map.get_revision(version)
        if revision is None:
            # ``get_revision(None)`` returns ``None``: the alembic version table exists but holds no revision.
            raise exceptions.StorageMigrationError(
                'No schema version could be read from the database: the '
                f"'{self.alembic_version_tbl_name}' table exists but contains no revision."
            )
        branches = revision.branch_labels

        if 'django' in branches or 'sqlalchemy' in branches:
            # migrate up to the top of the respective legacy branches
            if 'django' in branches:
                MIGRATE_LOGGER.report('Migrating to the head of the legacy django branch')
                self.migrate_up('django@head')
            elif 'sqlalchemy' in branches:
                MIGRATE_LOGGER.report('Migrating to the head of the legacy sqlalchemy branch')
                self.migrate_up('sqlalchemy@head')
            # now re-stamp with the comparable revision on the main branch
            with self._migration_context() as context:
                context._ensure_version_table(purge=True)
                context.stamp(context.script, 'main_0001')
                self.connection.commit()

        # finally migrate to the main head revision
        MIGRATE_LOGGER.report('Migrating to the head of the main branch')
        self.migrate_up('main@head')
        self.connection.commit()

    def migrate_up(self, version: str) -> None:
        """Migrate the database up to a specific version.

        :param version: string with schema version to migrate to
        """
        with self._alembic_connect() as config:
            upgrade(config, version)

    def migrate_down(self, version: str) -> None:
        """Migrate the database down to a specific version.

        :param version: string with schema version to migrate to
        """
        with self._alembic_connect() as config:
            downgrade(config, version)

    @staticmethod
    def _alembic_config():
        """Return an instance of an Alembic `Config` whose script location is the migrations folder of this module."""
        dirpath = pathlib.Path(__file__).resolve().parent
        config = Config()
        config.set_main_option('script_location', str(dirpath / ALEMBIC_REL_PATH))
        return config

    @classmethod
    def _alembic_script(cls):
        """Return an instance of an Alembic `ScriptDirectory`."""
        return ScriptDirectory.from_config(cls._alembic_config())

    @contextlib.contextmanager
    def _alembic_connect(self) -> Iterator[Config]:
        """Context manager to return an instance of an Alembic configuration.

        The profile's database connection is added in the `attributes` property, through which it can then also be
        retrieved, also in the `env.py` file, which is run when the database is migrated.
        """
        config = self._alembic_config()
        config.attributes['connection'] = self.connection
        config.attributes['aiida_profile'] = self.profile

        def _callback(step: MigrationInfo, **kwargs):
            """Callback to be called after a migration step is executed; reports the applied revision step."""
            from_rev = step.down_revision_ids[0] if step.down_revision_ids else '<base>'
            MIGRATE_LOGGER.report(f'- {from_rev} -> {step.up_revision_id}')

        config.attributes['on_version_apply'] = _callback

        yield config

    @contextlib.contextmanager
    def _migration_context(self) -> Iterator[MigrationContext]:
        """Context manager to return an instance of an Alembic migration context.

        This migration context will have been configured with the current database connection, which allows this context
        to be used to inspect the contents of the database, such as the current revision.
        """
        with self._alembic_connect() as config:
            script = ScriptDirectory.from_config(config)
            with EnvironmentContext(config, script) as context:
                context.configure(context.config.attributes['connection'])
                yield context.get_context()

    # the following are used for migration tests

    @contextlib.contextmanager
    def session(self) -> Iterator[Session]:
        """Context manager to return a session for the database.

        The session is rolled back if the managed block raises, and is always closed on exit.
        """
        session = Session(self._engine, future=True)

        try:
            yield session
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get_current_table(self, table_name: str) -> Any:
        """Return a table instantiated at the correct migration.

        Note that this is obtained by inspecting the database and not by looking into the models file.
        So, special methods possibly defined in the models files/classes are not present.

        :param table_name: Name of the database table to reflect.
        :return: The automapped class corresponding to ``table_name``.
        """
        base = automap_base()
        base.prepare(autoload_with=self.connection.engine)
        return getattr(base.classes, table_name)