###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# ruff: noqa: N802
"""This module contains the AiiDA backend ORM classes for the SQLite backend.

It reuses the classes already defined in the ``psql_dos`` backend (for PostgreSQL),
but redefines the SQLAlchemy models as SQLite-compatible ones.
"""
import json
from functools import singledispatch
from typing import Any, List, Optional, Tuple, Union
from sqlalchemy import JSON, case, func
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import ColumnElement
from aiida.common.lang import type_check
from aiida.storage.psql_dos.orm import authinfos, comments, computers, entities, groups, logs, nodes, users, utils
from aiida.storage.psql_dos.orm.querybuilder.main import (
BinaryExpression,
Cast,
ColumnClause,
InstrumentedAttribute,
Label,
QueryableAttribute,
SqlaQueryBuilder,
String,
get_column,
)
from . import models
from .utils import ReadOnlyError
class SqliteEntityOverride:
    """Mixin placed in front of a psql_dos ``Entity`` class to adapt it to the SQLite models.

    It replaces the model-class check so that ``MODEL_CLASS`` must derive from
    ``models.SqliteBase``, and it refuses to store entities in a read-only backend.
    """

    # The SQLAlchemy model class wrapped by the entity; validated by ``_class_check``.
    MODEL_CLASS: Any
    # Wrapper around the model instance; assigned in ``from_dbmodel``.
    _model: utils.ModelWrapper

    @classmethod
    def _class_check(cls):
        """Assert that ``MODEL_CLASS`` is a subclass of ``models.SqliteBase``."""
        assert issubclass(cls.MODEL_CLASS, models.SqliteBase), (
            'Must set the MODEL_CLASS in the derived class to a SQLA model'
        )

    @classmethod
    def from_dbmodel(cls, dbmodel, backend):
        """Build an AiiDA entity around an existing SQLA ORM model instance.

        :param dbmodel: the SQLAlchemy model instance to wrap
        :param backend: the storage backend the entity belongs to
        :return: the AiiDA entity
        """
        cls._class_check()
        type_check(dbmodel, cls.MODEL_CLASS)
        instance = cls.__new__(cls)
        # ``super(SqlaModelEntity, instance)`` resolves to the class *after* ``SqlaModelEntity``
        # in the MRO, so ``SqlaModelEntity.__init__`` itself is deliberately bypassed here.
        super(entities.SqlaModelEntity, instance).__init__(backend)  # type: ignore
        instance._model = utils.ModelWrapper(dbmodel, backend)
        return instance

    def store(self, *args, **kwargs):
        """Store the entity via the next ``store`` in the MRO.

        :raises ReadOnlyError: if the entity's backend is read-only
        """
        backend = self._model._backend
        if not backend.read_only:
            return super().store(*args, **kwargs)  # type: ignore
        raise ReadOnlyError(f'Cannot store entity in read-only backend: {backend}')
class SqliteUser(SqliteEntityOverride, users.SqlaUser):
    """SQLite backend ``User`` entity, wrapping a ``models.DbUser`` instance."""

    MODEL_CLASS = models.DbUser
class SqliteUserCollection(users.SqlaUserCollection):
    """Collection of :class:`SqliteUser` entities."""

    ENTITY_CLASS = SqliteUser
class SqliteComputer(SqliteEntityOverride, computers.SqlaComputer):
    """SQLite backend ``Computer`` entity, wrapping a ``models.DbComputer`` instance."""

    MODEL_CLASS = models.DbComputer
class SqliteComputerCollection(computers.SqlaComputerCollection):
    """Collection of :class:`SqliteComputer` entities."""

    ENTITY_CLASS = SqliteComputer
class SqliteAuthInfo(SqliteEntityOverride, authinfos.SqlaAuthInfo):
    """SQLite backend ``AuthInfo`` entity, wrapping a ``models.DbAuthInfo`` instance.

    Its related user and computer are represented by the SQLite entity classes.
    """

    MODEL_CLASS = models.DbAuthInfo
    USER_CLASS = SqliteUser
    COMPUTER_CLASS = SqliteComputer
class SqliteAuthInfoCollection(authinfos.SqlaAuthInfoCollection):
    """Collection of :class:`SqliteAuthInfo` entities."""

    ENTITY_CLASS = SqliteAuthInfo


class SqliteComment(SqliteEntityOverride, comments.SqlaComment):
    """SQLite backend ``Comment`` entity, wrapping a ``models.DbComment`` instance.

    ``get_backend_entity`` dispatches ``models.DbComment`` instances to
    ``SqliteComment.from_dbmodel``, so this class must exist at module level.
    """

    MODEL_CLASS = models.DbComment
    USER_CLASS = SqliteUser


class SqliteCommentCollection(comments.SqlaCommentCollection):
    """Collection of :class:`SqliteComment` entities."""

    ENTITY_CLASS = SqliteComment
class SqliteGroup(SqliteEntityOverride, groups.SqlaGroup):
    """SQLite backend ``Group`` entity, wrapping a ``models.DbGroup`` instance.

    Its owning user is represented by :class:`SqliteUser`.
    """

    MODEL_CLASS = models.DbGroup
    USER_CLASS = SqliteUser
class SqliteGroupCollection(groups.SqlaGroupCollection):
    """Collection of :class:`SqliteGroup` entities."""

    ENTITY_CLASS = SqliteGroup
class SqliteLog(SqliteEntityOverride, logs.SqlaLog):
    """SQLite backend ``Log`` entity, wrapping a ``models.DbLog`` instance."""

    MODEL_CLASS = models.DbLog
class SqliteLogCollection(logs.SqlaLogCollection):
    """Collection of :class:`SqliteLog` entities."""

    ENTITY_CLASS = SqliteLog
class SqliteNode(SqliteEntityOverride, nodes.SqlaNode):
    """SQLite backend ``Node`` entity, wrapping a ``models.DbNode`` instance.

    The related user and computer use the SQLite entity classes, and ``LINK_CLASS``
    points at the SQLite link model ``models.DbLink``.
    """

    MODEL_CLASS = models.DbNode
    USER_CLASS = SqliteUser
    COMPUTER_CLASS = SqliteComputer
    LINK_CLASS = models.DbLink
class SqliteNodeCollection(nodes.SqlaNodeCollection):
    """Collection of :class:`SqliteNode` entities."""

    ENTITY_CLASS = SqliteNode
class SqliteQueryBuilder(SqlaQueryBuilder):
    """QueryBuilder to use with SQLAlchemy-backend, adapted for SQLite.

    The entity properties point at the SQLite models in ``models``, and the projection and
    filter helpers are re-implemented on top of SQLite's JSON1 functions
    (``json_type``, ``json_each``, ``json_array_length``).
    """

    @property
    def Node(self):
        """The SQLite model for nodes."""
        return models.DbNode

    @property
    def Link(self):
        """The SQLite model for links."""
        return models.DbLink

    @property
    def Computer(self):
        """The SQLite model for computers."""
        return models.DbComputer

    @property
    def User(self):
        """The SQLite model for users."""
        return models.DbUser

    @property
    def Group(self):
        """The SQLite model for groups."""
        return models.DbGroup

    @property
    def AuthInfo(self):
        """The SQLite model for authorization info."""
        return models.DbAuthInfo

    @property
    def Comment(self):
        """The SQLite model for comments."""
        return models.DbComment

    @property
    def Log(self):
        """The SQLite model for logs."""
        return models.DbLog

    @property
    def table_groups_nodes(self):
        """The ``__table__`` of the group-node association model ``models.DbGroupNodes``."""
        return models.DbGroupNodes.__table__  # type: ignore[attr-defined]

    @staticmethod
    def _get_projectable_entity(
        alias: AliasedClass,
        column_name: str,
        attrpath: List[str],
        cast: Optional[str] = None,
    ) -> Union[ColumnElement, InstrumentedAttribute]:
        """Return the projectable entity for a column, optionally indexed by a JSON key path.

        :param alias: the aliased ORM class to project from
        :param column_name: the name of the column on ``alias``
        :param attrpath: the JSON key path to index into the column with; when it is empty and the
            column is neither ``attributes`` nor ``extras``, the plain column is returned and
            ``cast`` is ignored
        :param cast: optional cast of the JSON value: ``f`` (float), ``i`` (integer), ``b`` (boolean),
            ``t`` (text) or ``j`` (JSON)
        :return: the column, or the (possibly cast) JSON element expression
        :raises NotImplementedError: for the date cast ``d``, which SQLite does not support here
        :raises ValueError: for any other unknown cast key
        """
        if not (attrpath or column_name in ('attributes', 'extras')):
            return get_column(column_name, alias)
        entity = get_column(column_name, alias)[attrpath]
        if cast is None:
            pass
        elif cast == 'f':
            entity = entity.as_float()
        elif cast == 'i':
            entity = entity.as_integer()
        elif cast == 'b':
            entity = entity.as_boolean()
        elif cast == 't':
            entity = entity.as_string()
        elif cast == 'j':
            entity = entity.as_json()
        elif cast == 'd':
            raise NotImplementedError('Date casting (d) for JSON key, not implemented for sqlite backend')
        else:
            raise ValueError(f'Unknown casting key {cast}')
        return entity

    @staticmethod
    def get_filter_expr_from_jsonb(
        operator: str, value, attr_key: List[str], column=None, column_name=None, alias=None
    ):
        """Return a filter expression for a value stored under a JSON key path.

        Value comparisons are wrapped in ``CASE WHEN <json type check> THEN <comparison> ELSE false``
        so that only elements whose JSON type matches the Python type of ``value`` can satisfy them.

        See: https://www.sqlite.org/json1.html

        :param operator: the filter operator (``==``, ``>``, ``<``, ``>=``/``=>``, ``<=``/``=<``, ``of_type``,
            ``like``, ``ilike``, ``has_key``, ``in``, ``of_length``, ``longer``, ``shorter``)
        :param value: the value to filter against; for ``in`` a non-empty list, whose first element
            determines the expected JSON type
        :param attr_key: the JSON key path inside the column
        :param column: the JSON column; when ``None`` it is resolved from ``column_name`` and ``alias``
        :param column_name: the column name, used when ``column`` is not given
        :param alias: the aliased ORM class, used when ``column`` is not given
        :raises TypeError: if a compared ``value`` has a Python type with no JSON type mapping
        :raises ValueError: for an invalid ``of_type`` value or an unsupported operator
        """
        if column is None:
            column = get_column(column_name, alias)

        query_str = f'{alias or ""}.{column_name or ""}.{attr_key} {operator} {value}'

        def _cast_json_type(comparator: JSON.Comparator, value: Any) -> Tuple[ColumnElement, JSON.Comparator]:
            """Return a JSON type check for ``value`` and the comparator cast to the matching SQL type."""
            # ``bool`` must be tested before ``int`` because it is a subclass of ``int``.
            if isinstance(value, bool):
                # SQLite booleans in JSON evaluate to 0/1, see:
                # https://dba.stackexchange.com/questions/287377/how-can-i-set-a-json-value-to-a-boolean-in-sqlite
                return func.json_type(comparator) == 'integer', comparator.as_boolean()
            if isinstance(value, int):
                return func.json_type(comparator).in_(['integer', 'real']), comparator.as_integer()
            if isinstance(value, float):
                return func.json_type(comparator).in_(['integer', 'real']), comparator.as_float()
            if isinstance(value, str):
                return func.json_type(comparator) == 'text', comparator.as_string()
            if isinstance(value, list):
                return func.json_type(comparator) == 'array', comparator.as_json()
            if isinstance(value, dict):
                return func.json_type(comparator) == 'object', comparator.as_json()
            raise TypeError(f'Unsupported type {type(value)} for SQLite query: {query_str}')

        database_entity: JSON.Comparator = column[tuple(attr_key)]

        if operator == '==':
            # TODO: non-existent keys also have json_type null, so an ``== None`` filter would also
            # need to check that the key exists, e.g. ``func.json_type(database_entity) == 'null'``.
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            if isinstance(value, (list, dict)):
                # Compare against the value normalised by SQLite's ``json()`` function.
                return case((type_filter, casted_entity == func.json(json.dumps(value))), else_=False)
            return case((type_filter, casted_entity == value), else_=False)

        if operator == '>':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity > value), else_=False)

        if operator == '<':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity < value), else_=False)

        if operator in ('>=', '=>'):
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity >= value), else_=False)

        if operator in ('<=', '=<'):
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity <= value), else_=False)

        if operator == 'of_type':
            # convert from postgres types http://www.postgresql.org/docs/9.5/static/functions-json.html
            # for consistency with other backends
            valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null')
            type_map = {'object': 'object', 'array': 'array', 'string': 'text', 'null': 'null'}
            if value in type_map:
                return func.json_type(database_entity) == type_map[value]
            if value == 'boolean':
                # JSON booleans surface as the integers 0/1 (see ``_cast_json_type``), so an element
                # counts as boolean when it is an integer equal to 0 or 1. The previous version
                # compared this predicate against the string ``'boolean'``, which is always true in
                # SQLite (integers sort before text) and so matched every integer.
                type_filter = func.json_type(database_entity) == 'integer'
                value_filter = database_entity.as_boolean().in_([True, False])
                return case((type_filter, value_filter), else_=False)
            if value == 'number':
                return func.json_type(database_entity).in_(['integer', 'real'])
            raise ValueError(f'value {value!r} for `of_type` is not among valid types: {valid_types}')

        if operator == 'like':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity.like(value, escape='\\')), else_=False)

        if operator == 'ilike':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity.ilike(value, escape='\\')), else_=False)

        # TODO: ``contains`` is not supported yet, see:
        # https://github.com/sqlalchemy/sqlalchemy/discussions/7836

        if operator == 'has_key':
            return case(
                (
                    func.json_type(database_entity) == 'object',
                    func.json_each(database_entity).table_valued('key', joins_implicitly=True).c.key == value,
                ),
                else_=False,
            )

        if operator == 'in':
            type_filter, casted_entity = _cast_json_type(database_entity, value[0])
            return case((type_filter, casted_entity.in_(value)), else_=False)

        if operator == 'of_length':
            return case(
                (
                    func.json_type(database_entity) == 'array',
                    func.json_array_length(database_entity.as_json()) == value,
                ),
                else_=False,
            )

        if operator == 'longer':
            return case(
                (
                    func.json_type(database_entity) == 'array',
                    func.json_array_length(database_entity.as_json()) > value,
                ),
                else_=False,
            )

        if operator == 'shorter':
            return case(
                (
                    func.json_type(database_entity) == 'array',
                    func.json_array_length(database_entity.as_json()) < value,
                ),
                else_=False,
            )

        raise ValueError(f'SQLite does not support JSON query: {query_str}')

    @staticmethod
    def get_filter_expr_from_column(operator: str, value: Any, column) -> BinaryExpression:
        """Return a filter expression for a plain (non-JSON) column.

        :param operator: one of ``==``, ``>``, ``<``, ``>=``, ``<=``, ``like``, ``ilike``, ``in``
        :param value: the value to compare against
        :param column: the column, or column-like expression, to filter on
        :raises TypeError: if ``column`` is not one of the supported column expression types
        :raises ValueError: if ``operator`` is not supported
        """
        # Label is used because it is what is returned for the
        # 'state' column by the hybrid_column construct
        if not isinstance(column, (Cast, InstrumentedAttribute, QueryableAttribute, Label, ColumnClause)):
            raise TypeError(f'column ({type(column)}) {column} is not a valid column')
        database_entity = column
        if operator == '==':
            expr = database_entity == value
        elif operator == '>':
            expr = database_entity > value
        elif operator == '<':
            expr = database_entity < value
        elif operator == '>=':
            expr = database_entity >= value
        elif operator == '<=':
            expr = database_entity <= value
        elif operator == 'like':
            # the like operator expects a string, so we cast to avoid problems
            # with fields like UUID, which don't support the like operator
            expr = database_entity.cast(String).like(value, escape='\\')
        elif operator == 'ilike':
            expr = database_entity.ilike(value, escape='\\')
        elif operator == 'in':
            expr = database_entity.in_(value)
        else:
            raise ValueError(f'Unknown operator {operator} for filters on columns')
        return expr
@singledispatch
def get_backend_entity(dbmodel, backend):
    """Return the AiiDA backend entity corresponding to a SQLite ORM model instance.

    This generic implementation is the fallback for model types with no registered
    conversion; concrete conversions are added with ``get_backend_entity.register``.

    :param dbmodel: the SQLAlchemy model instance to convert
    :param backend: the storage backend the entity belongs to
    :raises TypeError: if no conversion is registered for the class of ``dbmodel``
    """
    model_name = dbmodel.__class__.__name__
    raise TypeError(f"No corresponding AiiDA backend class exists for the model class '{model_name}'")
# Per-model implementations of ``get_backend_entity``: each one is selected through the type
# annotation on ``dbmodel`` and wraps the model instance in the matching SQLite entity class.


@get_backend_entity.register
def _(dbmodel: models.DbUser, backend):
    return SqliteUser.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbGroup, backend):
    return SqliteGroup.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbComputer, backend):
    return SqliteComputer.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbNode, backend):
    return SqliteNode.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbAuthInfo, backend):
    return SqliteAuthInfo.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbComment, backend):
    # NOTE(review): requires a module-level ``SqliteComment`` class; the name is only resolved when
    # this branch is actually called.
    return SqliteComment.from_dbmodel(dbmodel, backend)


@get_backend_entity.register
def _(dbmodel: models.DbLog, backend):
    return SqliteLog.from_dbmodel(dbmodel, backend)
@get_backend_entity.register
def _(dbmodel: models.DbLink, backend):
    """Convert a ``DbLink`` row into a ``LinkQuadruple``; ``backend`` is not used."""
    from aiida.orm.utils.links import LinkQuadruple

    source_id, target_id = dbmodel.input_id, dbmodel.output_id
    link_type, link_label = dbmodel.type, dbmodel.label
    return LinkQuadruple(source_id, target_id, link_type, link_label)