# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from __future__ import absolute_import
from sqlalchemy import orm
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.exc import UnmappedClassError
import aiida.backends.sqlalchemy
from aiida.common.exceptions import InvalidOperation
# Taken from
# https://github.com/mitsuhiko/flask-sqlalchemy/blob/master/flask_sqlalchemy/__init__.py#L491
class _QueryProperty(object):
    """Descriptor that builds a query for the class it is accessed on.

    Adapted from Flask-SQLAlchemy's query property.
    """

    def __init__(self, query_class=orm.Query):
        """Remember which query class to instantiate on attribute access.

        :param query_class: class used to construct the query object;
            defaults to ``sqlalchemy.orm.Query``.
        """
        self.query_class = query_class

    def __get__(self, obj, _type):
        """Return a query over ``_type`` bound to the scoped session.

        :param obj: instance the attribute was accessed through (unused).
        :param _type: class the attribute was accessed on.
        :return: an instance of ``query_class``, or ``None`` when ``_type``
            is not a mapped class.
        """
        # Only the mapper lookup is expected to raise UnmappedClassError, so
        # keep the guarded region limited to that call.
        try:
            mapper = orm.class_mapper(_type)
        except UnmappedClassError:
            return None

        if not mapper:
            return None

        return self.query_class(
            mapper, session=aiida.backends.sqlalchemy.get_scoped_session())
class _SessionProperty(object):
    """Descriptor that exposes the scoped SQLAlchemy session."""

    def __get__(self, obj, _type):
        """Return the scoped session.

        :param obj: instance the attribute was accessed through (unused).
        :param _type: class the attribute was accessed on (unused).
        :return: the object returned by ``get_scoped_session()``.
        :raises InvalidOperation: if ``get_scoped_session()`` returns a
            falsy value.
        """
        # Fetch once so that the value that is checked is the value that is
        # returned.
        session = aiida.backends.sqlalchemy.get_scoped_session()
        if not session:
            raise InvalidOperation("You need to call load_dbenv before "
                                   "accessing the session of SQLALchemy.")
        return session
class _AiidaQuery(orm.Query):
    """Query whose iteration converts ``Model`` results to AiiDA classes."""

    def __init__(self, *args, **kwargs):
        """Constructor; forwards all arguments to ``sqlalchemy.orm.Query``."""
        super(_AiidaQuery, self).__init__(*args, **kwargs)

    def __iter__(self):
        """Iterate over the query results.

        Results that are ``Model`` instances are yielded as the value of
        their ``get_aiida_class()`` method; any other result is yielded
        unchanged.
        """
        for row in super(_AiidaQuery, self).__iter__():
            # Allow the use of with_entities: results that are not Model
            # instances are passed through untouched.
            if isinstance(row, Model):
                yield row.get_aiida_class()
            else:
                yield row
from aiida.backends.sqlalchemy import get_scoped_session
class Model(object):
    """Base class adding query/session access and persistence helpers."""

    # Descriptors, resolved on attribute access: ``query`` builds a query for
    # the accessing class, ``session`` returns the scoped session.
    query = _QueryProperty()
    session = _SessionProperty()

    def save(self, commit=True):
        """Add this object to the scoped session.

        :param commit: if True (the default), commit the session after the
            object has been added.
        :return: this object.
        """
        sess = get_scoped_session()
        sess.add(self)
        if commit:
            sess.commit()
        return self

    def delete(self, commit=True):
        """Delete this object through the scoped session.

        Unlike :meth:`save`, this method returns ``None``.

        :param commit: if True (the default), commit the session after the
            object has been marked for deletion.
        """
        sess = get_scoped_session()
        sess.delete(self)
        if commit:
            sess.commit()
# Declarative base for the SQLAlchemy models. It is generated from ``Model``,
# so classes derived from ``Base`` get its ``query``/``session`` descriptors
# and the ``save``/``delete`` helpers.
Base = declarative_base(cls=Model, name='Model')