diff --git a/app/src/App.tsx b/app/src/App.tsx index 5ecd1e3b2..bb394883b 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -134,17 +134,15 @@ export default function App() { // if step changes useEffect(() => { socket.emit("room:frames:get", [step], (frames: Frames) => { - // map positions: numbers[][] to THREE.Vector3[] - for (const key in frames) { if (frames.hasOwnProperty(key)) { - const frame = frames[key]; + const frame: Frame = frames[key]["value"]; frame.positions = frame.positions.map( (position) => new THREE.Vector3(position[0], position[1], position[2]), ) as THREE.Vector3[]; } - setCurrentFrame(frames[step]); + setCurrentFrame(frames[step]["value"]); setNeedsUpdate(false); // rename this to something more descriptive } }); diff --git a/app/src/components/particles.tsx b/app/src/components/particles.tsx index fbdd57eb8..27cbb328d 100644 --- a/app/src/components/particles.tsx +++ b/app/src/components/particles.tsx @@ -17,7 +17,7 @@ export interface Frame { ``; export interface Frames { - [key: number]: Frame; + [key: number]: { _type: string; value: Frame }; } export const Player = ({ diff --git a/poetry.lock b/poetry.lock index 8396630fc..04db9c825 100644 --- a/poetry.lock +++ b/poetry.lock @@ -284,7 +284,7 @@ files = [ name = "attrs" version = "23.2.0" description = "Classes Without Boilerplate" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, @@ -4790,33 +4790,15 @@ networkx = ">=3.0,<4.0" [package.extras] dask = ["bokeh (>=2.4.2,<3.0.0)", "dask (>=2022.12.1,<2023.0.0)", "dask-jobqueue (>=0.8.1,<0.9.0)", "distributed (>=2022.12.1,<2023.0.0)"] -[[package]] -name = "znframe" -version = "0.1.5" -description = "ZnFrame - ASE-like Interface based on dataclasses" -optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "znframe-0.1.5-py3-none-any.whl", hash = "sha256:3645ee836b19f3d957686ba58f2fd4c142afef44df33764856848579e3af61f4"}, - {file = "znframe-0.1.5.tar.gz", hash = "sha256:b5424e196af4df7fc8662c8ce6be5950a1307b763d4159bad69a0559448f6acf"}, -] - -[package.dependencies] -ase = ">=3.22.1,<4.0.0" -attrs = ">=23.1.0,<24.0.0" -networkx = ">=3.2.1,<4.0.0" -numpy = ">=1.26.2,<2.0.0" -pydantic = ">=2.5.3,<3.0.0" - [[package]] name = "znh5md" -version = "0.2.0" +version = "0.2.1" description = "High Performance Interface for H5MD Trajectories" optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "znh5md-0.2.0-py3-none-any.whl", hash = "sha256:563e5c9b6a1e29864d31c4731416eb4b79f7acfee23cbe8cc53730b247d1125b"}, - {file = "znh5md-0.2.0.tar.gz", hash = "sha256:da60ca2e53f2d74cdc3a5954dbadae912973e4a29456329d2e6406f8d5e794bf"}, + {file = "znh5md-0.2.1-py3-none-any.whl", hash = "sha256:73a4393544ed791d6aaa3850a30e14dd4f04e820c3cf5ea5f0aa74eab1be52d0"}, + {file = "znh5md-0.2.1.tar.gz", hash = "sha256:aa10de0219f0abc0ab685c1e3d91b556ea9e03cca8e08c8d491e50dcd511cb0b"}, ] [package.dependencies] @@ -4903,4 +4885,4 @@ rdkit = ["rdkit2ase"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "018f62b93cd7caba2762e1242d2ac48f03f403cd17bbdac6279f2d151ec4ccca" +content-hash = "76d89fb129c3dfe3da2cc64d4f88c05e1c1722008287e4dafa40207f16aa06f6" diff --git a/pyproject.toml b/pyproject.toml index 4f49cb2e6..5fcd83f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,13 @@ mdanalysis = {version = "^2", optional = true} tidynamics = {version = "^1", optional = true} rdkit2ase = {version = "^0.1", optional = true} zntrack = {version = "^0.7.3", optional = true} -znframe = "^0.1" celery = "^5" sqlalchemy = "^2" redis = "^5" splines = "^0.3" znsocket = "^0.1.6" lazy-loader = "^0.4" +znjson = "^0.2.3" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 000000000..25681ca0c --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,44 @@ +import ase +import numpy.testing as npt +import pytest +import znjson +from ase.calculators.singlepoint import SinglePointCalculator + +from zndraw.utils import ASEConverter + + +def test_ase_converter(s22): + s22[0].connectivity = [[0, 1, 1], [1, 2, 1], [2, 3, 1]] + s22[3].calc = SinglePointCalculator(s22[3]) + s22[3].calc.results = {"energy": 0.0, "predicted_energy": 1.0} + s22[4].info = {"key": "value"} + + structures_json = znjson.dumps( + s22, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) + structures = znjson.loads( + structures_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) + for s1, s2 in zip(s22, structures): + assert s1 == s2 + + npt.assert_array_equal(structures[0].connectivity, [[0, 1, 1], [1, 2, 1], [2, 3, 1]]) + with pytest.raises(AttributeError): + _ = structures[1].connectivity + + assert structures[3].calc.results == {"energy": 0.0, "predicted_energy": 1.0} + + assert "colors" in structures[0].arrays + assert "radii" in structures[0].arrays + + assert structures[4].info == {"key": "value"} + + +def test_exotic_atoms(): + atoms = ase.Atoms("X", positions=[[0, 0, 0]]) + new_atoms = znjson.loads( + znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])), + cls=znjson.ZnDecoder.from_converters([ASEConverter]), + ) + npt.assert_array_equal(new_atoms.arrays["colors"], ["#ff0000"]) + npt.assert_array_equal(new_atoms.arrays["radii"], [0.3]) diff --git a/zndraw/base.py b/zndraw/base.py index 9bdb992fb..4f9194bfd 100644 --- a/zndraw/base.py +++ b/zndraw/base.py @@ -7,12 +7,14 @@ import ase import numpy as np import splines -import znframe +import znjson import znsocket from flask import current_app, session from pydantic import BaseModel, Field, create_model from redis import Redis +from zndraw.utils import ASEConverter + log = logging.getLogger(__name__) @@ -42,7 +44,9 @@ def get_atoms() -> ase.Atoms: lst = znsocket.List(r, key) try: frame_json = lst[int(step)] - return znframe.Frame.from_json(frame_json).to_atoms() + return znjson.loads( + frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) except TypeError: # step is None return ase.Atoms() diff --git a/zndraw/bonds/__init__.py b/zndraw/bonds/__init__.py index e5aff3d85..6dfc5e9b2 100644 --- a/zndraw/bonds/__init__.py +++ b/zndraw/bonds/__init__.py @@ -47,8 +47,9 @@ def remove_edge(graph, atom_1, atom_2): except NetworkXError: pass - def get_bonds(self, atoms: ase.Atoms): - graph = atoms.connectivity + def get_bonds(self, atoms: ase.Atoms, graph: nx.Graph = None): + if graph is None: + graph = self.build_graph(atoms) bonds = [] for edge in graph.edges: bonds.append((edge[0], edge[1], graph.edges[edge]["weight"])) diff --git a/zndraw/modify/__init__.py b/zndraw/modify/__init__.py index 5ca95c450..65bbf9250 100644 --- a/zndraw/modify/__init__.py +++ b/zndraw/modify/__init__.py @@ -8,7 +8,6 @@ import numpy as np from ase.data import chemical_symbols from pydantic import Field -from znframe.frame import get_radius from zndraw.base import Extension, MethodsCollection @@ -50,13 +49,13 @@ class Connect(UpdateScene): """Create guiding curve between selected atoms.""" def run(self, vis: "ZnDraw", **kwargs) -> None: + atoms = vis.atoms atom_ids = vis.selection atom_positions = vis.atoms.get_positions() - atom_numbers = vis.atoms.numbers[atom_ids] camera_position = np.array(vis.camera["position"])[None, :] # 1,3 new_points = atom_positions[atom_ids] # N, 3 - radii: np.ndarray = get_radius(atom_numbers)[0][:, None] # N, 1 + radii: np.ndarray = atoms.arrays["radii"][atom_ids][:, None] direction = camera_position - new_points direction /= np.linalg.norm(direction, axis=1, keepdims=True) new_points += direction * radii @@ -118,7 +117,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None: else: for idx, atom_id in enumerate(sorted(atom_ids)): atoms.pop(atom_id - idx) # we remove the atom and shift the index - del atoms.connectivity + if hasattr(atoms, "connectivity"): + del atoms.connectivity vis.append(atoms) vis.selection = [] vis.step += 1 diff --git a/zndraw/scene.py b/zndraw/scene.py index b4e5bd164..9523ffa3d 100644 --- a/zndraw/scene.py +++ b/zndraw/scene.py @@ -1,12 +1,14 @@ import enum import ase -import znframe +import znjson import znsocket from flask import current_app, session from pydantic import BaseModel, Field from redis import Redis +from zndraw.utils import ASEConverter + class Material(str, enum.Enum): MeshBasicMaterial = "MeshBasicMaterial" @@ -74,7 +76,9 @@ def _get_atoms() -> ase.Atoms: lst = znsocket.List(r, key) try: frame_json = lst[int(step)] - return znframe.Frame.from_json(frame_json).to_atoms() + return znjson.loads( + frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) except TypeError: # step is None return ase.Atoms() diff --git a/zndraw/tasks/__init__.py b/zndraw/tasks/__init__.py index 7e13bb1b4..7fef23fdc 100644 --- a/zndraw/tasks/__init__.py +++ b/zndraw/tasks/__init__.py @@ -4,12 +4,14 @@ import ase.io import socketio.exceptions import tqdm -import znframe +import znjson import znsocket from celery import shared_task from flask import current_app from zndraw.base import FileIO +from zndraw.bonds import ASEComputeBonds +from zndraw.utils import ASEConverter log = logging.getLogger(__name__) @@ -37,6 +39,7 @@ def read_file(fileio: dict) -> None: r.delete("room:default:frames") lst = znsocket.List(r, "room:default:frames") + bonds_calculator = ASEComputeBonds() if file_io.name is None: @@ -77,8 +80,11 @@ def _generator(): break if file_io.step and idx % file_io.step != 0: continue - frame = znframe.Frame.from_atoms(atoms) - lst.append(frame.to_json()) + if not hasattr(atoms, "connectivity"): + atoms.connectivity = bonds_calculator.get_bonds(atoms) + lst.append( + znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])) + ) if idx == 0: try: io.connect(current_app.config["SERVER_URL"], wait_timeout=10) diff --git a/zndraw/utils.py b/zndraw/utils.py index 0001c6092..ecf680db3 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -1,3 +1,4 @@ +import functools import importlib.util import json import logging @@ -8,12 +9,159 @@ import typing as t import uuid +import ase import datamodel_code_generator +import numpy as np import socketio.exceptions +from ase.calculators.singlepoint import SinglePointCalculator +from ase.data import covalent_radii +from ase.data.colors import jmol_colors +from znjson import ConverterBase log = logging.getLogger(__name__) +class ASEDict(t.TypedDict): + numbers: list[int] + positions: list[list[float]] + connectivity: list[tuple[int, int, int]] + arrays: dict[str, list[float | int | list[float | int]]] + info: dict[str, float | int] + # calc: dict[str, float|int|np.ndarray] # should this be split into arrays and info? + pbc: list[bool] + cell: list[list[float]] + + +def rgb2hex(value): + r, g, b = np.array(value * 255, dtype=int) + return "#%02x%02x%02x" % (r, g, b) + + +@functools.lru_cache(maxsize=128) +def get_scaled_radii() -> np.ndarray: + """Scale down the covalent radii to visualize bonds better.""" + radii = covalent_radii + # shift the values such that they are in [0.3, 1.3] + radii = radii - np.min(radii) + radii = radii / np.max(radii) + radii = radii + 0.3 + return radii + + +class ASEConverter(ConverterBase): + """Encode/Decode datetime objects + + Attributes + ---------- + level: int + Priority of this converter over others. + A higher level will be used first, if there + are multiple converters available + representation: str + An unique identifier for this converter. + instance: + Used to select the correct converter. + This should fulfill isinstance(other, self.instance) + or __eq__ should be overwritten. + """ + + level = 100 + representation = "ase.Atoms" + instance = ase.Atoms + + def encode(self, obj: ase.Atoms) -> ASEDict: + """Convert the datetime object to str / isoformat""" + + numbers = obj.numbers.tolist() + positions = obj.positions.tolist() + pbc = obj.pbc.tolist() + cell = obj.cell.tolist() + + info = { + k: v + for k, v in obj.info.items() + if isinstance(v, (float, int, str, bool, list)) + } + info |= {k: v.tolist() for k, v in obj.info.items() if isinstance(v, np.ndarray)} + + if obj.calc is not None: + calc = { + k: v + for k, v in obj.calc.results.items() + if isinstance(v, (float, int, str, bool, list)) + } + calc |= { + k: v.tolist() + for k, v in obj.calc.results.items() + if isinstance(v, np.ndarray) + } + else: + calc = {} + + # All additional information should be stored in calc.results + # and not in calc.arrays, thus we will not convert it here! + arrays = {} + if "colors" not in obj.arrays: + arrays["colors"] = [rgb2hex(jmol_colors[number]) for number in numbers] + else: + arrays["colors"] = ( + obj.arrays["colors"].tolist() + if isinstance(obj.arrays["colors"], np.ndarray) + else obj.arrays["colors"] + ) + + if "radii" not in obj.arrays: + # arrays["radii"] = [covalent_radii[number] for number in numbers] + arrays["radii"] = [get_scaled_radii()[number] for number in numbers] + else: + arrays["radii"] = ( + obj.arrays["radii"].tolist() + if isinstance(obj.arrays["radii"], np.ndarray) + else obj.arrays["radii"] + ) + + if hasattr(obj, "connectivity") and obj.connectivity is not None: + connectivity = ( + obj.connectivity.tolist() + if isinstance(obj.connectivity, np.ndarray) + else obj.connectivity + ) + else: + connectivity = [] + + return ASEDict( + numbers=numbers, + positions=positions, + connectivity=connectivity, + arrays=arrays, + info=info, + calc=calc, + pbc=pbc, + cell=cell, + ) + + def decode(self, value: ASEDict) -> ase.Atoms: + """Create datetime object from str / isoformat""" + atoms = ase.Atoms( + numbers=value["numbers"], + positions=value["positions"], + info=value["info"], + pbc=value["pbc"], + cell=value["cell"], + ) + if connectivity := value.get("connectivity"): + # or do we want this to be nx.Graph? + atoms.connectivity = np.array(connectivity) + if "colors" in value["arrays"]: + atoms.arrays["colors"] = np.array(value["arrays"]["colors"]) + if "radii" in value["arrays"]: + atoms.arrays["radii"] = np.array(value["arrays"]["radii"]) + if calc := value.get("calc"): + atoms.calc = SinglePointCalculator(atoms) + atoms.calc.results.update(calc) + return atoms + + def get_port(default: int) -> int: """Get an open port.""" try: @@ -78,52 +226,6 @@ def get_cls_from_json_schema(schema: dict, name: str, **kwargs): return getattr(module, name) -def ensure_path(path: str): - """Ensure that a path exists.""" - p = pathlib.Path(path).expanduser() - p.mkdir(parents=True, exist_ok=True) - return p.as_posix() - - -def wrap_and_check_index(index: int | slice | list[int], length: int) -> list[int]: - is_slice = isinstance(index, slice) - if is_slice: - index = list(range(*index.indices(length))) - index = [index] if isinstance(index, int) else index - index = [i if i >= 0 else length + i for i in index] - # check if index is out of range - for i in index: - if i >= length: - raise IndexError(f"Index {i} out of range for length {length}") - if i < 0: - raise IndexError(f"Index {i-length} out of range for length {length}") - return index - - -def check_selection(value: list[int], maximum: int): - """Check if the selection is valid - - Attributes - ---------- - value: list[int] - the selected indices - maximum: int - len(vis.step), will be incremented by one, to account for - """ - if not isinstance(value, list): - raise ValueError("Selection must be a list") - if any(not isinstance(i, int) for i in value): - raise ValueError("Selection must be a list of integers") - if len(value) != len(set(value)): - raise ValueError("Selection must be unique") - if any(i < 0 for i in value): - raise ValueError("Selection must be positive integers") - if any(i >= maximum for i in value): - raise ValueError( - f"Can not select particles indices larger than size of the scene: {maximum }. Got {value}" - ) - - def emit_with_retry( socket: socketio.Client, event, diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index 98c670570..529f29c28 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -8,13 +8,14 @@ import numpy as np import socketio.exceptions import tqdm -import znframe +import znjson import znsocket from redis import Redis from zndraw.base import Extension, ZnDrawBase +from zndraw.bonds import ASEComputeBonds from zndraw.draw import Geometry, Object3D -from zndraw.utils import call_with_retry, emit_with_retry +from zndraw.utils import ASEConverter, call_with_retry, emit_with_retry log = logging.getLogger(__name__) @@ -86,6 +87,10 @@ class ZnDraw(ZnDrawBase): default_factory=datetime.datetime.now ) + bond_calculator: ASEComputeBonds | None = dataclasses.field( + default_factory=ASEComputeBonds, repr=False + ) + def __post_init__(self): def on_wakeup(): if self._available: @@ -154,7 +159,12 @@ def __getitem__(self, index) -> ase.Atoms | list[ase.Atoms]: retries=self.timeout["call_retries"], ) - structures = [znframe.Frame(**x).to_atoms() for x in data.values()] + structures = [ + znjson.loads( + json.dumps(x), cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) + for x in data.values() + ] return structures[0] if single_item else structures def __setitem__( @@ -163,11 +173,16 @@ def __setitem__( if isinstance(index, slice): index = list(range(*index.indices(len(self)))) if isinstance(index, int): - data = {index: znframe.Frame.from_atoms(value).to_json()} - else: - data = { - i: znframe.Frame.from_atoms(val).to_json() for i, val in zip(index, value) - } + index = [index] + value = [value] + + data = {} + for i, val in zip(index, value): + if not hasattr(val, "connectivity") and self.bond_calculator is not None: + val.connectivity = self.bond_calculator.get_bonds(val) + data[i] = znjson.dumps( + val, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) call_with_retry( self.socket, @@ -199,10 +214,18 @@ def __delitem__(self, index: int | slice | list[int]): ) def insert(self, index: int, value: ase.Atoms): + if hasattr(value, "connectivity") and self.bond_calculator is not None: + value.connectivity = self.bond_calculator.get_bonds(value) + call_with_retry( self.socket, "room:frames:insert", - {"index": index, "value": znframe.Frame.from_atoms(value).to_json()}, + { + "index": index, + "value": znjson.dumps( + value, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ), + }, retries=self.timeout["call_retries"], ) @@ -211,10 +234,13 @@ def extend(self, values: list[ase.Atoms]): # enable tbar if more than 10 messages are sent # approximated by the size of the first frame + show_tbar = ( len(values) * len( - json.dumps(znframe.Frame.from_atoms(values[0]).to_json()).encode("utf-8") + znjson.dumps( + values[0], cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ).encode("utf-8") ) ) > (10 * self.maximum_message_size) tbar = tqdm.tqdm( @@ -222,7 +248,12 @@ def extend(self, values: list[ase.Atoms]): ) for i, val in enumerate(tbar, start=len(self)): - msg[i] = znframe.Frame.from_atoms(val).to_json() + if not hasattr(val, "connectivity") and self.bond_calculator is not None: + val.connectivity = self.bond_calculator.get_bonds(val) + + msg[i] = znjson.dumps( + val, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) if len(json.dumps(msg).encode("utf-8")) > self.maximum_message_size: call_with_retry( self.socket, @@ -539,7 +570,10 @@ def __getitem__(self, index: int | list | slice) -> ase.Atoms | list[ase.Atoms]: except IndexError: data = [] - structures = [znframe.Frame.from_json(x).to_atoms() for x in data] + structures = [ + znjson.loads(x, cls=znjson.ZnDecoder.from_converters([ASEConverter])) + for x in data + ] if single_item: return structures[0] return structures