Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre/post processors to allow nan and inf values to be stored in JSON #13

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions aiida/orm/nodes/data/msonable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
"""Data plugin for classes that implement the ``MSONable`` class of the ``monty`` library."""
import importlib
import json

from aiida.orm import Data


class MsonableData(Data):
"""Data plugin that allows to easily wrap objects that are MSONable.

To use this class, simply construct it passing an isntance of any ``MSONable`` class and store it, for example:

from pymatgen.core import Molecule

molecule = Molecule(['H']. [0, 0, 0])
node = MsonableData(molecule)
node.store()

After storing, the node can be loaded like any other node and the original MSONable instance can be retrieved:

loaded = load_node(node.pk)
molecule = loaded.obj

.. note:: As the ``MSONable`` mixin class requires, the wrapped object needs to implement the methods ``as_dict``
and ``from_dict``. A default implementation should be present on the ``MSONable`` base class itself, but it
might need to be overridden in a specific implementation.

"""

def __init__(self, obj, *args, **kwargs):
"""Construct the node from the pymatgen object."""
if obj is None:
raise TypeError('the `obj` argument cannot be `None`.')

for method in ['as_dict', 'from_dict']:
if not hasattr(obj, method) or not callable(getattr(obj, method)):
raise TypeError(f'the `obj` argument does not have the required `{method}` method.')

super().__init__(*args, **kwargs)

self._obj = obj

# Serialize the object by calling ``as_dict`` and performing a roundtrip through JSON encoding.
# This relies on obj.as_dict() giving JSON serializable outputs. The round trip is necessary for
# constants NaN, inf, -inf which are serialised (by JSONEncoder) as plain strings and kept during
# the deserialization.
serialized = json.loads(json.dumps(obj.as_dict()), parse_constant=lambda x: x)

# Then we apply our own custom serializer that serializes the float constants infinity and nan to a string value
# which is necessary because the serializer of the ``json`` standard module deserializes to the Python values
# that can not be written to JSON.
self.set_attribute_many(serialized)

@classmethod
def _deserialize_float_constants(cls, data):
"""Deserialize the contents of a dictionary ``data`` deserializing infinity and NaN string constants.

The ``data`` dictionary is recursively checked for the ``Infinity``, ``-Infinity`` and ``NaN`` strings, which
are the Javascript string equivalents to the Python ``float('inf')``, ``-float('inf')`` and ``float('nan')``
float constants. If one of the strings is encountered, the Python float constant is returned and otherwise the
original value is returned.
"""
if isinstance(data, dict):
return {k: cls._deserialize_float_constants(v) for k, v in data.items()}
if isinstance(data, list):
return [cls._deserialize_float_constants(v) for v in data]
if data == 'Infinity':
return float('inf')
if data == '-Infinity':
return -float('inf')
if data == 'NaN':
return float('nan')
return data

def _get_object(self):
"""Return the cached wrapped MSONable object.

.. note:: If the object is not yet present in memory, for example if the node was loaded from the database,
the object will first be reconstructed from the state stored in the node attributes.

"""
try:
return self._obj
except AttributeError:
attributes = self.attributes
class_name = attributes['@class']
module_name = attributes['@module']

try:
module = importlib.import_module(module_name)
except ImportError as exc:
raise ImportError(f'the objects module `{module_name}` can not be imported.') from exc

try:
cls = getattr(module, class_name)
except AttributeError as exc:
raise ImportError(
f'the objects module `{module_name}` does not contain the class `{class_name}`.'
) from exc

# First we need to deserialize any infinity or nan float string markers that were serialized in the
# constructor of this node when it was created. There the decoding step in the JSON roundtrip defined a
# pass-through for the ``parse_constant`` argument, which means that the serialized versions of the float
# constants (i.e. the strings ``Infinity`` etc.) are not deserialized in the Python float constants. Here we
# need to first explicit deserialize them. One would think that we could simply let the ``json.loads`` in
# the following step take care of this, however, since the attributes would first be serialized by the
# ``json.dumps`` call, the string placeholders would be dumped again to an actual string, which would then
# no longer be recognized by ``json.loads`` as the Javascript notation of the float constants and so it will
# leave them as separate strings.
deserialized = self._deserialize_float_constants(attributes)

# Finally, reconstruct the original ``MSONable`` class from the fully deserialized data.
self._obj = cls.from_dict(deserialized)

return self._obj

@property
def obj(self):
"""Return the wrapped MSONable object."""
return self._get_object()
1 change: 1 addition & 0 deletions setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
"folder = aiida.orm.nodes.data.folder:FolderData",
"int = aiida.orm.nodes.data.int:Int",
"list = aiida.orm.nodes.data.list:List",
"msonable = aiida.orm.nodes.data.msonable:MsonableData",
"numeric = aiida.orm.nodes.data.numeric:NumericType",
"orbital = aiida.orm.nodes.data.orbital:OrbitalData",
"remote = aiida.orm.nodes.data.remote.base:RemoteData",
Expand Down
227 changes: 227 additions & 0 deletions tests/orm/data/test_msonable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# -*- coding: utf-8 -*-
"""Tests for the :class:`aiida.orm.nodes.data.msonable.MsonableData` data type."""
import datetime
import math
from json import loads, dumps

from monty.json import MSONable, MontyEncoder
import pymatgen
import pytest
import numpy

from aiida.orm import load_node
from aiida.orm.nodes.data.msonable import MsonableData


class MsonableClass(MSONable):
"""Dummy class that implements the ``MSONable interface``."""

def __init__(self, data):
"""Construct a new object."""
self._data = data

@property
def data(self):
"""Return the data of this instance."""
return self._data

def as_dict(self):
"""Represent the object as a JSON-serializable dictionary."""
return {
'@module': self.__class__.__module__,
'@class': self.__class__.__name__,
'data': self._data,
}

@classmethod
def from_dict(cls, d):
"""Reconstruct an instance from a serialized version."""
return cls(data=d['data'])


class MsonableClass2(MSONable):
"""Dummy class that implements the ``MSONable interface``."""

def __init__(self, obj, array, timestamp=None):
"""Construct a new object."""
self._obj = obj
self._array = array
if timestamp is None:
self._timestamp = datetime.datetime.now()
else:
self._timestamp = timestamp

@property
def obj(self):
"""Return the data of this instance."""
return self._obj

@property
def array(self):
"""Return the data of this instance."""
return self._array

@property
def timestamp(self):
"""Return the timestamp"""
return self._timestamp

def as_dict(self):
"""Represent the object as a JSON-serializable dictionary."""
return_dict = {
'@module': self.__class__.__module__,
'@class': self.__class__.__name__,
'obj': self.obj.as_dict(),
'timestamp': loads(dumps(self.timestamp, cls=MontyEncoder)),
'array': loads(dumps(self._array, cls=MontyEncoder))
}
return return_dict


def test_construct():
"""Test the ``MsonableData`` constructor."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)

assert isinstance(node, MsonableData)
assert not node.is_stored


def test_constructor_object_none():
"""Test the ``MsonableData`` constructor raises if object is ``None``."""
with pytest.raises(TypeError, match=r'the `obj` argument cannot be `None`.'):
MsonableData(None)


def test_invalid_class_no_as_dict():
"""Test the ``MsonableData`` constructor raises if object does not sublass ``MSONable``."""

class InvalidClass(MSONable):

@classmethod
def from_dict(cls, d):
pass

# Remove the ``as_dict`` method from the ``MSONable`` base class because that is currently implemented by default.
del MSONable.as_dict

with pytest.raises(TypeError, match=r'the `obj` argument does not have the required `as_dict` method.'):
MsonableData(InvalidClass())


@pytest.mark.usefixtures('clear_database_before_test')
def test_store():
"""Test storing a ``MsonableData`` instance."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)
assert not node.is_stored

node.store()
assert node.is_stored


@pytest.mark.usefixtures('clear_database_before_test')
def test_load():
"""Test loading a ``MsonableData`` instance."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)
node.store()

loaded = load_node(node.pk)
assert isinstance(node, MsonableData)
assert loaded == node


@pytest.mark.usefixtures('clear_database_before_test')
def test_obj():
"""Test the ``MsonableData.obj`` property."""
data = [1, float('inf'), float('-inf'), float('nan')]
obj = MsonableClass(data)
node = MsonableData(obj)
node.store()

assert isinstance(node.obj, MsonableClass)
assert node.obj.data == data

loaded = load_node(node.pk)
assert isinstance(node.obj, MsonableClass)

for left, right in zip(loaded.obj.data, data):

# Need this explicit case to compare NaN because of the peculiarity in Python where ``float(nan) != float(nan)``
if isinstance(left, float) and math.isnan(left):
assert math.isnan(right)
continue

try:
# This is needed to match numpy arrays
assert (left == right).all()
except AttributeError:
assert left == right


@pytest.mark.usefixtures('clear_database_before_test')
def test_complex_obj():
"""Test the ``MsonableData.obj`` property for a more complex class."""
data = [1, float('inf'), float('-inf'), float('nan')]
obj = MsonableClass(data)
obj2 = MsonableClass2(obj=obj, array=numpy.arange(10))
node = MsonableData(obj2)
node.store()

assert isinstance(node.obj, MsonableClass2)
assert node.obj.obj.data == data

loaded = load_node(node.pk)
assert isinstance(node.obj, MsonableClass2)

for left, right in zip(loaded.obj.obj.data, data):

# Need this explicit case to compare NaN because of the peculiarity in Python where ``float(nan) != float(nan)``
if isinstance(left, float) and math.isnan(left):
assert math.isnan(right)
continue

try:
# This is needed to match numpy arrays
assert (left == right).all()
except AttributeError:
assert left == right

assert isinstance(loaded.obj.timestamp, datetime.datetime)
numpy.testing.assert_allclose(loaded.obj.array, numpy.arange(10))


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_module():
"""Test the ``MsonableData.obj`` property if the associated module cannot be loaded."""
obj = pymatgen.core.Molecule(['H'], [[0, 0, 0]])
node = MsonableData(obj)

# Artificially change the ``@module`` in the attributes so it becomes unloadable
node.set_attribute('@module', 'not.existing')
node.store()

loaded = load_node(node.pk)

with pytest.raises(ImportError, match='the objects module `not.existing` can not be imported.'):
_ = loaded.obj


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_class():
"""Test the ``MsonableData.obj`` property if the associated class cannot be loaded."""
obj = pymatgen.core.Molecule(['H'], [[0, 0, 0]])
node = MsonableData(obj)

# Artificially change the ``@class`` in the attributes so it becomes unloadable
node.set_attribute('@class', 'NonExistingClass')
node.store()

loaded = load_node(node.pk)

with pytest.raises(ImportError, match=r'the objects module `.*` does not contain the class `NonExistingClass`.'):
_ = loaded.obj