###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Migration from the "legacy" JSON format, to an sqlite database, and node uuid based repository to hash based."""
import shutil
import tarfile
from contextlib import contextmanager
from datetime import datetime
from hashlib import sha256
from pathlib import Path, PurePosixPath
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Tuple, Union
from archive_path import ZipPath
from sqlalchemy import insert, select
from sqlalchemy.exc import IntegrityError
from aiida.common.exceptions import CorruptStorage, StorageMigrationError
from aiida.common.hashing import chunked_file_hash
from aiida.common.progress_reporter import get_progress_reporter
from aiida.repository.common import File, FileType
from aiida.storage.log import MIGRATE_LOGGER
from ..utils import DB_FILENAME, REPO_FOLDER, create_sqla_engine
from . import v1_db_schema as v1_schema
from .utils import update_metadata
# Names of the AiiDA ORM entity types as they appear in the legacy JSON archive's
# ``export_data`` mapping.
_NODE_ENTITY_NAME = 'Node'
_GROUP_ENTITY_NAME = 'Group'
_COMPUTER_ENTITY_NAME = 'Computer'
_USER_ENTITY_NAME = 'User'
_LOG_ENTITY_NAME = 'Log'
_COMMENT_ENTITY_NAME = 'Comment'
# Per entity type: mapping from the field name used in the JSON file to the
# corresponding column name in the sqlite (v1 schema) model.
# Entities absent here (e.g. User) need no field renaming.
file_fields_to_model_fields: Dict[str, Dict[str, str]] = {
    _NODE_ENTITY_NAME: {'dbcomputer': 'dbcomputer_id', 'user': 'user_id'},
    _GROUP_ENTITY_NAME: {'user': 'user_id'},
    _COMPUTER_ENTITY_NAME: {},
    _LOG_ENTITY_NAME: {'dbnode': 'dbnode_id'},
    _COMMENT_ENTITY_NAME: {'dbnode': 'dbnode_id', 'user': 'user_id'},
}
# Mapping from entity type name to its sqlalchemy model in the v1 archive schema.
aiida_orm_to_backend = {
    _USER_ENTITY_NAME: v1_schema.DbUser,
    _GROUP_ENTITY_NAME: v1_schema.DbGroup,
    _NODE_ENTITY_NAME: v1_schema.DbNode,
    _COMMENT_ENTITY_NAME: v1_schema.DbComment,
    _COMPUTER_ENTITY_NAME: v1_schema.DbComputer,
    _LOG_ENTITY_NAME: v1_schema.DbLog,
}
# Schema revision recorded for archives produced by this legacy -> main migration.
LEGACY_TO_MAIN_REVISION = 'main_0000'
def _json_to_sqlite(
    outpath: Path, data: dict, node_repos: Dict[str, List[Tuple[str, Optional[str]]]], batch_size: int = 100
) -> None:
    """Convert a JSON archive format to SQLite.

    :param outpath: path at which to create the sqlite database.
    :param data: the loaded legacy JSON metadata dictionary (``export_data``, ``links_uuid``, ...).
    :param node_repos: mapping of node UUID to its repository file entries,
        as ``(relative posix path, hashkey or None for directories)`` tuples.
    :param batch_size: number of rows to insert per statement execution.
    :raises ~aiida.common.exceptions.StorageMigrationError: on integrity errors or
        links referencing nodes that are not present in the archive.
    """
    from aiida.tools.archive.common import batch_iter

    MIGRATE_LOGGER.report('Converting DB to SQLite')

    engine = create_sqla_engine(outpath)
    v1_schema.ArchiveV1Base.metadata.create_all(engine)

    try:
        with engine.begin() as connection:
            # proceed in order of relationships, so foreign keys always reference existing rows
            for entity_type in (
                _USER_ENTITY_NAME,
                _COMPUTER_ENTITY_NAME,
                _GROUP_ENTITY_NAME,
                _NODE_ENTITY_NAME,
                _LOG_ENTITY_NAME,
                _COMMENT_ENTITY_NAME,
            ):
                if not data['export_data'].get(entity_type, {}):
                    continue
                length = len(data['export_data'].get(entity_type, {}))
                backend_cls = aiida_orm_to_backend[entity_type]
                with get_progress_reporter()(desc=f'Adding {entity_type}s', total=length) as progress:
                    for nrows, rows in batch_iter(_iter_entity_fields(data, entity_type, node_repos), batch_size):
                        # to-do check for unused keys?
                        # to-do handle null values?
                        try:
                            connection.execute(insert(backend_cls.__table__), rows)  # type: ignore
                        except IntegrityError as exc:
                            raise StorageMigrationError(f'Database integrity error: {exc}') from exc
                        progress.update(nrows)

        if not (data['groups_uuid'] or data['links_uuid']):
            return None

        with engine.begin() as connection:
            # get mapping of node UUIDs to node IDs
            node_uuid_map = {
                uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id))
            }

            # links
            if data['links_uuid']:

                def _transform_link(link_row):
                    """Convert a legacy link row (node UUIDs) into a DbLink row (node IDs)."""
                    try:
                        input_id = node_uuid_map[link_row['input']]
                    except KeyError as exc:
                        raise StorageMigrationError(
                            f'Database contains link with unknown input node: {link_row}'
                        ) from exc
                    try:
                        output_id = node_uuid_map[link_row['output']]
                    except KeyError as exc:
                        raise StorageMigrationError(
                            f'Database contains link with unknown output node: {link_row}'
                        ) from exc
                    return {
                        'input_id': input_id,
                        'output_id': output_id,
                        'label': link_row['label'],
                        'type': link_row['type'],
                    }

                with get_progress_reporter()(desc='Adding Links', total=len(data['links_uuid'])) as progress:
                    for nrows, rows in batch_iter(data['links_uuid'], batch_size, transform=_transform_link):
                        connection.execute(insert(v1_schema.DbLink.__table__), rows)
                        progress.update(nrows)

            # groups to nodes
            if data['groups_uuid']:
                # get mapping of group UUIDs to group IDs
                group_uuid_map = {
                    uuid: pk
                    for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id))
                }
                length = sum(len(uuids) for uuids in data['groups_uuid'].values())
                unknown_nodes: Dict[str, set] = {}
                with get_progress_reporter()(desc='Adding Group-Nodes', total=length) as progress:
                    for group_uuid, node_uuids in data['groups_uuid'].items():
                        group_id = group_uuid_map[group_uuid]
                        rows = []
                        for uuid in node_uuids:
                            if uuid in node_uuid_map:
                                rows.append({'dbnode_id': node_uuid_map[uuid], 'dbgroup_id': group_id})
                            else:
                                # node referenced by the group is absent from the archive: drop it
                                # (collected and reported once, below)
                                unknown_nodes.setdefault(group_uuid, set()).add(uuid)
                        if rows:  # executing an insert with an empty parameter list is an error
                            connection.execute(insert(v1_schema.DbGroupNodes.__table__), rows)
                        progress.update(len(node_uuids))
                if unknown_nodes:
                    MIGRATE_LOGGER.warning(f'Dropped unknown nodes in groups: {unknown_nodes}')
    finally:
        # release the sqlite file handle held by the engine's connection pool
        engine.dispose()
[docs]
def _convert_datetime(key, value):
if key in ('time', 'ctime', 'mtime') and value is not None:
return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%f')
return value
def _iter_entity_fields(
    data,
    name: str,
    node_repos: Dict[str, List[Tuple[str, Optional[str]]]],
) -> Iterator[Dict[str, Any]]:
    """Yield the field dictionary for each row of the given entity type.

    Field names are renamed to their sqlite column names and datetime strings
    are parsed; for nodes, the attributes, extras and repository metadata are
    merged into the row.
    """
    field_map = file_fields_to_model_fields.get(name, {})
    entity_rows = data['export_data'].get(name, {})

    if name != _NODE_ENTITY_NAME:
        for pk, fields in entity_rows.items():
            row = {field_map.get(key, key): _convert_datetime(key, value) for key, value in fields.items()}
            row['id'] = pk
            yield row
        return

    # for nodes, the attributes and extras are merged in before yielding
    attributes = data.get('node_attributes', {})
    extras = data.get('node_extras', {})
    for pk, fields in entity_rows.items():
        if pk not in attributes:
            raise CorruptStorage(f'Unable to find attributes info for Node with Pk={pk}')
        if pk not in extras:
            raise CorruptStorage(f'Unable to find extra info for Node with Pk={pk}')
        uuid = fields['uuid']
        row = {field_map.get(key, key): _convert_datetime(key, value) for key, value in fields.items()}
        row.update(
            {
                'id': pk,
                'attributes': attributes[pk],
                'extras': extras[pk],
                'repository_metadata': _create_repo_metadata(node_repos[uuid]) if uuid in node_repos else {},
            }
        )
        yield row
def _create_directory(top_level: File, path: PurePosixPath) -> File:
    """Walk ``path`` below ``top_level``, creating any missing directory levels.

    :param top_level: the root ``File`` under which the path is resolved.
    :param path: the relative path of the directory.
    :return: the ``File`` object for the final path component.
    """
    current = top_level
    for segment in path.parts:
        if segment in current.objects:
            current = current.objects[segment]
        else:
            new_dir = File(segment)
            current.objects[segment] = new_dir
            current = new_dir
    return current