From 662d5669247d4a8fcffd63dcf5f53ebced0dc7e6 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 10:38:53 +0200 Subject: [PATCH 01/17] add znjson converter --- poetry.lock | 10 ++-- pyproject.toml | 1 + tests/test_serializer.py | 11 ++++ zndraw/utils.py | 107 ++++++++++++++++++++++----------------- 4 files changed, 78 insertions(+), 51 deletions(-) create mode 100644 tests/test_serializer.py diff --git a/poetry.lock b/poetry.lock index 8396630fc..714ae9da5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -972,9 +972,9 @@ isort = ">=4.3.21,<6.0" jinja2 = ">=2.10.1,<4.0" packaging = "*" pydantic = [ - {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, + {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] pyyaml = ">=6.0.1" toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""} @@ -3021,9 +3021,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4903,4 +4903,4 @@ rdkit = ["rdkit2ase"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "018f62b93cd7caba2762e1242d2ac48f03f403cd17bbdac6279f2d151ec4ccca" +content-hash = "28904682037b0f015c31ee09e1451c0db226f8a83e92dfe0a028323deba93ff1" diff --git a/pyproject.toml b/pyproject.toml index 4f49cb2e6..debbd46d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ redis = "^5" splines = "^0.3" znsocket = "^0.1.6" lazy-loader = "^0.4" +znjson = "^0.2.3" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 000000000..b2bfb9d60 --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,11 @@ +import json + +import znjson + +from zndraw.utils import ASEConverter + + +def test_ase_converter(s22): + atoms_json = json.dumps(s22[0], cls=znjson.ZnEncoder.from_converters([ASEConverter])) + atoms = json.loads(atoms_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])) + assert atoms == s22[0] diff --git a/zndraw/utils.py b/zndraw/utils.py index 0001c6092..bc03251bf 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -8,12 +8,73 @@ import typing as t import uuid +import ase import datamodel_code_generator import socketio.exceptions +from znjson import ConverterBase log = logging.getLogger(__name__) +class ASEDict(t.TypedDict): + numbers: list[int] + positions: list[list[float]] + connectivity: list[tuple[int, int, int]] + arrays: dict[str, list[float | int | list[float | int]]] + info: dict[str, float | int] + # calc: dict[str, float|int|np.ndarray] # should this be split into arrays and info? + pbc: list[bool] + cell: list[list[float]] + + +class ASEConverter(ConverterBase): + """Encode/Decode datetime objects + + Attributes + ---------- + level: int + Priority of this converter over others. + A higher level will be used first, if there + are multiple converters available + representation: str + An unique identifier for this converter. + instance: + Used to select the correct converter. + This should fulfill isinstance(other, self.instance) + or __eq__ should be overwritten. + """ + + level = 100 + representation = "ase.Atoms" + instance = ase.Atoms + + def encode(self, obj: ase.Atoms) -> ASEDict: + """Convert the datetime object to str / isoformat""" + return ASEDict( + numbers=obj.numbers.tolist(), + positions=obj.positions.tolist(), + connectivity=None, + arrays={}, + info={}, + # calc=obj.calc.results, + pbc=obj.pbc.tolist(), + cell=obj.cell.tolist(), + ) + + def decode(self, value: ASEDict) -> ase.Atoms: + """Create datetime object from str / isoformat""" + return ase.Atoms( + numbers=value["numbers"], + positions=value["positions"], + # connectivity=value["connectivity"], + # arrays=value["arrays"], + info=value["info"], + # calc=value["calc"], + pbc=value["pbc"], + cell=value["cell"], + ) + + def get_port(default: int) -> int: """Get an open port.""" try: @@ -78,52 +139,6 @@ def get_cls_from_json_schema(schema: dict, name: str, **kwargs): return getattr(module, name) -def ensure_path(path: str): - """Ensure that a path exists.""" - p = pathlib.Path(path).expanduser() - p.mkdir(parents=True, exist_ok=True) - return p.as_posix() - - -def wrap_and_check_index(index: int | slice | list[int], length: int) -> list[int]: - is_slice = isinstance(index, slice) - if is_slice: - index = list(range(*index.indices(length))) - index = [index] if isinstance(index, int) else index - index = [i if i >= 0 else length + i for i in index] - # check if index is out of range - for i in index: - if i >= length: - raise IndexError(f"Index {i} out of range for length {length}") - if i < 0: - raise IndexError(f"Index {i-length} out of range for length {length}") - return index - - -def check_selection(value: list[int], maximum: int): - """Check if the selection is valid - - Attributes - ---------- - value: list[int] - the selected indices - maximum: int - len(vis.step), will be incremented by one, to account for - """ - if not isinstance(value, list): - raise ValueError("Selection must be a list") - if any(not isinstance(i, int) for i in value): - raise ValueError("Selection must be a list of integers") - if len(value) != len(set(value)): - raise ValueError("Selection must be unique") - if any(i < 0 for i in value): - raise ValueError("Selection must be positive integers") - if any(i >= maximum for i in value): - raise ValueError( - f"Can not select particles indices larger than size of the scene: {maximum }. Got {value}" - ) - - def emit_with_retry( socket: socketio.Client, event, From b4d1029a5101329d9d15c7b9291dcb6da7ded303 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 11:01:06 +0200 Subject: [PATCH 02/17] complete ASEConverter --- tests/test_serializer.py | 29 +++++++++++-- zndraw/utils.py | 92 ++++++++++++++++++++++++++++++++++------ 2 files changed, 106 insertions(+), 15 deletions(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index b2bfb9d60..e196eebe0 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,11 +1,34 @@ import json +import pytest import znjson +from ase.calculators.singlepoint import SinglePointCalculator from zndraw.utils import ASEConverter def test_ase_converter(s22): - atoms_json = json.dumps(s22[0], cls=znjson.ZnEncoder.from_converters([ASEConverter])) - atoms = json.loads(atoms_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])) - assert atoms == s22[0] + s22[0].connectivity = [[0, 1, 1], [1, 2, 1], [2, 3, 1]] + s22[3].calc = SinglePointCalculator(s22[3]) + s22[3].calc.results = {"energy": 0.0, "predicted_energy": 1.0} + s22[4].info = {"key": "value"} + + structures_json = json.dumps( + s22, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) + structures = json.loads( + structures_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) + for s1, s2 in zip(s22, structures): + assert s1 == s2 + + assert structures[0].connectivity == [[0, 1, 1], [1, 2, 1], [2, 3, 1]] + with pytest.raises(AttributeError): + _ = structures[1].connectivity + + assert structures[3].calc.results == {"energy": 0.0, "predicted_energy": 1.0} + + assert "colors" in structures[0].arrays + assert "radii" in structures[0].arrays + + assert structures[4].info == {"key": "value"} diff --git a/zndraw/utils.py b/zndraw/utils.py index bc03251bf..432b527e3 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -10,7 +10,11 @@ import ase import datamodel_code_generator +import numpy as np import socketio.exceptions +from ase.calculators.singlepoint import SinglePointCalculator +from ase.data.colors import jmol_colors +from ase.data.vdw import vdw_radii from znjson import ConverterBase log = logging.getLogger(__name__) @@ -27,6 +31,11 @@ class ASEDict(t.TypedDict): cell: list[list[float]] +def rgb2hex(value): + r, g, b = np.array(value * 255, dtype=int) + return "#%02x%02x%02x" % (r, g, b) + + class ASEConverter(ConverterBase): """Encode/Decode datetime objects @@ -50,29 +59,88 @@ class ASEConverter(ConverterBase): def encode(self, obj: ase.Atoms) -> ASEDict: """Convert the datetime object to str / isoformat""" + + numbers = obj.numbers.tolist() + positions = obj.positions.tolist() + pbc = obj.pbc.tolist() + cell = obj.cell.tolist() + + info = { + k: v + for k, v in obj.info.items() + if isinstance(v, (float, int, str, bool, list)) + } + info |= {k: v.tolist() for k, v in obj.info.items() if isinstance(v, np.ndarray)} + + if obj.calc is not None: + calc = { + k: v + for k, v in obj.calc.results.items() + if isinstance(v, (float, int, str, bool, list)) + } + calc |= { + k: v.tolist() + for k, v in obj.calc.results.items() + if isinstance(v, np.ndarray) + } + else: + calc = {} + + arrays = {} + if "colors" not in obj.arrays: + arrays["colors"] = [rgb2hex(jmol_colors[number]) for number in numbers] + else: + arrays["colors"] = ( + obj.arrays["colors"].tolist() + if isinstance(obj.arrays["colors"], np.ndarray) + else obj.arrays["colors"] + ) + + if "radii" not in obj.arrays: + arrays["radii"] = [vdw_radii[number] for number in numbers] + else: + arrays["radii"] = ( + obj.arrays["radii"].tolist() + if isinstance(obj.arrays["radii"], np.ndarray) + else obj.arrays["radii"] + ) + + if hasattr(obj, "connectivity") and obj.connectivity is not None: + connectivity = ( + obj.connectivity.tolist() + if isinstance(obj.connectivity, np.ndarray) + else obj.connectivity + ) + else: + connectivity = [] + return ASEDict( - numbers=obj.numbers.tolist(), - positions=obj.positions.tolist(), - connectivity=None, - arrays={}, - info={}, - # calc=obj.calc.results, - pbc=obj.pbc.tolist(), - cell=obj.cell.tolist(), + numbers=numbers, + positions=positions, + connectivity=connectivity, + arrays=arrays, + info=info, + calc=calc, + pbc=pbc, + cell=cell, ) def decode(self, value: ASEDict) -> ase.Atoms: """Create datetime object from str / isoformat""" - return ase.Atoms( + atoms = ase.Atoms( numbers=value["numbers"], positions=value["positions"], - # connectivity=value["connectivity"], - # arrays=value["arrays"], info=value["info"], - # calc=value["calc"], pbc=value["pbc"], cell=value["cell"], ) + if connectivity := value.get("connectivity"): + atoms.connectivity = connectivity + atoms.arrays.update(value["arrays"]) + if calc := value.get("calc"): + atoms.calc = SinglePointCalculator(atoms) + atoms.calc.results.update(calc) + return atoms def get_port(default: int) -> int: From 35854865d741447001e1019e01fd86678a5e4612 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 11:18:20 +0200 Subject: [PATCH 03/17] replace znframe with znjson --- tests/test_serializer.py | 6 ++---- zndraw/base.py | 8 ++++++-- zndraw/modify/__init__.py | 5 ++--- zndraw/scene.py | 8 ++++++-- zndraw/tasks/__init__.py | 8 +++++--- zndraw/utils.py | 2 ++ zndraw/zndraw.py | 39 ++++++++++++++++++++++++++++++--------- 7 files changed, 53 insertions(+), 23 deletions(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index e196eebe0..ac3316d7f 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,5 +1,3 @@ -import json - import pytest import znjson from ase.calculators.singlepoint import SinglePointCalculator @@ -13,10 +11,10 @@ def test_ase_converter(s22): s22[3].calc.results = {"energy": 0.0, "predicted_energy": 1.0} s22[4].info = {"key": "value"} - structures_json = json.dumps( + structures_json = znjson.dumps( s22, cls=znjson.ZnEncoder.from_converters([ASEConverter]) ) - structures = json.loads( + structures = znjson.loads( structures_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) ) for s1, s2 in zip(s22, structures): diff --git a/zndraw/base.py b/zndraw/base.py index 9bdb992fb..4f9194bfd 100644 --- a/zndraw/base.py +++ b/zndraw/base.py @@ -7,12 +7,14 @@ import ase import numpy as np import splines -import znframe +import znjson import znsocket from flask import current_app, session from pydantic import BaseModel, Field, create_model from redis import Redis +from zndraw.utils import ASEConverter + log = logging.getLogger(__name__) @@ -42,7 +44,9 @@ def get_atoms() -> ase.Atoms: lst = znsocket.List(r, key) try: frame_json = lst[int(step)] - return znframe.Frame.from_json(frame_json).to_atoms() + return znjson.loads( + frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) except TypeError: # step is None return ase.Atoms() diff --git a/zndraw/modify/__init__.py b/zndraw/modify/__init__.py index 5ca95c450..434fb1194 100644 --- a/zndraw/modify/__init__.py +++ b/zndraw/modify/__init__.py @@ -8,7 +8,6 @@ import numpy as np from ase.data import chemical_symbols from pydantic import Field -from znframe.frame import get_radius from zndraw.base import Extension, MethodsCollection @@ -50,13 +49,13 @@ class Connect(UpdateScene): """Create guiding curve between selected atoms.""" def run(self, vis: "ZnDraw", **kwargs) -> None: + atoms = vis.atoms atom_ids = vis.selection atom_positions = vis.atoms.get_positions() - atom_numbers = vis.atoms.numbers[atom_ids] camera_position = np.array(vis.camera["position"])[None, :] # 1,3 new_points = atom_positions[atom_ids] # N, 3 - radii: np.ndarray = get_radius(atom_numbers)[0][:, None] # N, 1 + radii: np.ndarray = atoms.arrays["radii"][atom_ids] direction = camera_position - new_points direction /= np.linalg.norm(direction, axis=1, keepdims=True) new_points += direction * radii diff --git a/zndraw/scene.py b/zndraw/scene.py index b4e5bd164..9523ffa3d 100644 --- a/zndraw/scene.py +++ b/zndraw/scene.py @@ -1,12 +1,14 @@ import enum import ase -import znframe +import znjson import znsocket from flask import current_app, session from pydantic import BaseModel, Field from redis import Redis +from zndraw.utils import ASEConverter + class Material(str, enum.Enum): MeshBasicMaterial = "MeshBasicMaterial" @@ -74,7 +76,9 @@ def _get_atoms() -> ase.Atoms: lst = znsocket.List(r, key) try: frame_json = lst[int(step)] - return znframe.Frame.from_json(frame_json).to_atoms() + return znjson.loads( + frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) except TypeError: # step is None return ase.Atoms() diff --git a/zndraw/tasks/__init__.py b/zndraw/tasks/__init__.py index 7e13bb1b4..1c0a2cad9 100644 --- a/zndraw/tasks/__init__.py +++ b/zndraw/tasks/__init__.py @@ -4,12 +4,13 @@ import ase.io import socketio.exceptions import tqdm -import znframe +import znjson import znsocket from celery import shared_task from flask import current_app from zndraw.base import FileIO +from zndraw.utils import ASEConverter log = logging.getLogger(__name__) @@ -77,8 +78,9 @@ def _generator(): break if file_io.step and idx % file_io.step != 0: continue - frame = znframe.Frame.from_atoms(atoms) - lst.append(frame.to_json()) + lst.append( + znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])) + ) if idx == 0: try: io.connect(current_app.config["SERVER_URL"], wait_timeout=10) diff --git a/zndraw/utils.py b/zndraw/utils.py index 432b527e3..394c9699f 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -86,6 +86,8 @@ def encode(self, obj: ase.Atoms) -> ASEDict: else: calc = {} + # All additional information should be stored in calc.results + # and not in calc.arrays, thus we will not convert it here! arrays = {} if "colors" not in obj.arrays: arrays["colors"] = [rgb2hex(jmol_colors[number]) for number in numbers] diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index 0b026b96e..21fe26f55 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -8,13 +8,13 @@ import numpy as np import socketio.exceptions import tqdm -import znframe +import znjson import znsocket from redis import Redis from zndraw.base import Extension, ZnDrawBase from zndraw.draw import Geometry, Object3D -from zndraw.utils import call_with_retry, emit_with_retry +from zndraw.utils import ASEConverter, call_with_retry, emit_with_retry log = logging.getLogger(__name__) @@ -154,7 +154,10 @@ def __getitem__(self, index) -> ase.Atoms | list[ase.Atoms]: retries=self.timeout["call_retries"], ) - structures = [znframe.Frame(**x).to_atoms() for x in data.values()] + structures = [ + znjson.loads(x, cls=znjson.ZnDecoder.from_converters([ASEConverter])) + for x in data.values() + ] return structures[0] if single_item else structures def __setitem__( @@ -163,10 +166,15 @@ def __setitem__( if isinstance(index, slice): index = list(range(*index.indices(len(self)))) if isinstance(index, int): - data = {index: znframe.Frame.from_atoms(value).to_json()} + data = { + index: znjson.dumps( + value, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) + } else: data = { - i: znframe.Frame.from_atoms(val).to_json() for i, val in zip(index, value) + i: znjson.dumps(val, cls=znjson.ZnEncoder.from_converters([ASEConverter])) + for i, val in zip(index, value) } call_with_retry( @@ -202,7 +210,12 @@ def insert(self, index: int, value: ase.Atoms): call_with_retry( self.socket, "room:frames:insert", - {"index": index, "value": znframe.Frame.from_atoms(value).to_json()}, + { + "index": index, + "value": znjson.dumps( + value, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ), + }, retries=self.timeout["call_retries"], ) @@ -211,10 +224,13 @@ def extend(self, values: list[ase.Atoms]): # enable tbar if more than 10 messages are sent # approximated by the size of the first frame + show_tbar = ( len(values) * len( - json.dumps(znframe.Frame.from_atoms(values[0]).to_json()).encode("utf-8") + znjson.dumps( + values[0], cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ).encode("utf-8") ) ) > (10 * self.maximum_message_size) tbar = tqdm.tqdm( @@ -222,7 +238,9 @@ def extend(self, values: list[ase.Atoms]): ) for i, val in enumerate(tbar, start=len(self)): - msg[i] = znframe.Frame.from_atoms(val).to_json() + msg[i] = znjson.dumps( + val, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) if len(json.dumps(msg).encode("utf-8")) > self.maximum_message_size: call_with_retry( self.socket, @@ -535,7 +553,10 @@ def __getitem__(self, index: int | list | slice) -> ase.Atoms | list[ase.Atoms]: except IndexError: data = [] - structures = [znframe.Frame.from_json(x).to_atoms() for x in data] + structures = [ + znjson.loads(x, cls=znjson.ZnDecoder.from_converters([ASEConverter])) + for x in data + ] if single_item: return structures[0] return structures From 71124432cab053eb7ef8bbbafdaf65aa30c81384 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 11:36:52 +0200 Subject: [PATCH 04/17] replace znframe everywhere --- app/src/App.tsx | 5 +++-- app/src/components/particles.tsx | 2 +- zndraw/modify/__init__.py | 5 +++-- zndraw/utils.py | 9 ++++++--- zndraw/zndraw.py | 4 +++- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/app/src/App.tsx b/app/src/App.tsx index 5ecd1e3b2..0f3931250 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -134,17 +134,18 @@ export default function App() { // if step changes useEffect(() => { socket.emit("room:frames:get", [step], (frames: Frames) => { + console.log(frames); // map positions: numbers[][] to THREE.Vector3[] for (const key in frames) { if (frames.hasOwnProperty(key)) { - const frame = frames[key]; + const frame: Frame = frames[key]["value"]; frame.positions = frame.positions.map( (position) => new THREE.Vector3(position[0], position[1], position[2]), ) as THREE.Vector3[]; } - setCurrentFrame(frames[step]); + setCurrentFrame(frames[step]["value"]); setNeedsUpdate(false); // rename this to something more descriptive } }); diff --git a/app/src/components/particles.tsx b/app/src/components/particles.tsx index fbdd57eb8..27cbb328d 100644 --- a/app/src/components/particles.tsx +++ b/app/src/components/particles.tsx @@ -17,7 +17,7 @@ export interface Frame { ``; export interface Frames { - [key: number]: Frame; + [key: number]: { _type: string; value: Frame }; } export const Player = ({ diff --git a/zndraw/modify/__init__.py b/zndraw/modify/__init__.py index 434fb1194..65bbf9250 100644 --- a/zndraw/modify/__init__.py +++ b/zndraw/modify/__init__.py @@ -55,7 +55,7 @@ def run(self, vis: "ZnDraw", **kwargs) -> None: camera_position = np.array(vis.camera["position"])[None, :] # 1,3 new_points = atom_positions[atom_ids] # N, 3 - radii: np.ndarray = atoms.arrays["radii"][atom_ids] + radii: np.ndarray = atoms.arrays["radii"][atom_ids][:, None] direction = camera_position - new_points direction /= np.linalg.norm(direction, axis=1, keepdims=True) new_points += direction * radii @@ -117,7 +117,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None: else: for idx, atom_id in enumerate(sorted(atom_ids)): atoms.pop(atom_id - idx) # we remove the atom and shift the index - del atoms.connectivity + if hasattr(atoms, "connectivity"): + del atoms.connectivity vis.append(atoms) vis.selection = [] vis.step += 1 diff --git a/zndraw/utils.py b/zndraw/utils.py index 394c9699f..6b98fa147 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -13,8 +13,8 @@ import numpy as np import socketio.exceptions from ase.calculators.singlepoint import SinglePointCalculator +from ase.data import covalent_radii from ase.data.colors import jmol_colors -from ase.data.vdw import vdw_radii from znjson import ConverterBase log = logging.getLogger(__name__) @@ -99,7 +99,7 @@ def encode(self, obj: ase.Atoms) -> ASEDict: ) if "radii" not in obj.arrays: - arrays["radii"] = [vdw_radii[number] for number in numbers] + arrays["radii"] = [covalent_radii[number] for number in numbers] else: arrays["radii"] = ( obj.arrays["radii"].tolist() @@ -138,7 +138,10 @@ def decode(self, value: ASEDict) -> ase.Atoms: ) if connectivity := value.get("connectivity"): atoms.connectivity = connectivity - atoms.arrays.update(value["arrays"]) + if "colors" in value["arrays"]: + atoms.arrays["colors"] = np.array(value["arrays"]["colors"]) + if "radii" in value["arrays"]: + atoms.arrays["radii"] = np.array(value["arrays"]["radii"]) if calc := value.get("calc"): atoms.calc = SinglePointCalculator(atoms) atoms.calc.results.update(calc) diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index 21fe26f55..ae09959a8 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -155,7 +155,9 @@ def __getitem__(self, index) -> ase.Atoms | list[ase.Atoms]: ) structures = [ - znjson.loads(x, cls=znjson.ZnDecoder.from_converters([ASEConverter])) + znjson.loads( + json.dumps(x), cls=znjson.ZnDecoder.from_converters([ASEConverter]) + ) for x in data.values() ] return structures[0] if single_item else structures From 5df7c332e496a25c9aee3f0511ed69655ef44784 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 11:38:01 +0200 Subject: [PATCH 05/17] remove znframe --- poetry.lock | 36 +++++++++--------------------------- pyproject.toml | 1 - 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/poetry.lock b/poetry.lock index 714ae9da5..04db9c825 100644 --- a/poetry.lock +++ b/poetry.lock @@ -284,7 +284,7 @@ files = [ name = "attrs" version = "23.2.0" description = "Classes Without Boilerplate" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, @@ -972,9 +972,9 @@ isort = ">=4.3.21,<6.0" jinja2 = ">=2.10.1,<4.0" packaging = "*" pydantic = [ - {version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, - {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, ] pyyaml = ">=6.0.1" toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""} @@ -3021,9 +3021,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4790,33 +4790,15 @@ networkx = ">=3.0,<4.0" [package.extras] dask = ["bokeh (>=2.4.2,<3.0.0)", "dask (>=2022.12.1,<2023.0.0)", "dask-jobqueue (>=0.8.1,<0.9.0)", "distributed (>=2022.12.1,<2023.0.0)"] -[[package]] -name = "znframe" -version = "0.1.5" -description = "ZnFrame - ASE-like Interface based on dataclasses" -optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "znframe-0.1.5-py3-none-any.whl", hash = "sha256:3645ee836b19f3d957686ba58f2fd4c142afef44df33764856848579e3af61f4"}, - {file = "znframe-0.1.5.tar.gz", hash = "sha256:b5424e196af4df7fc8662c8ce6be5950a1307b763d4159bad69a0559448f6acf"}, -] - -[package.dependencies] -ase = ">=3.22.1,<4.0.0" -attrs = ">=23.1.0,<24.0.0" -networkx = ">=3.2.1,<4.0.0" -numpy = ">=1.26.2,<2.0.0" -pydantic = ">=2.5.3,<3.0.0" - [[package]] name = "znh5md" -version = "0.2.0" +version = "0.2.1" description = "High Performance Interface for H5MD Trajectories" optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "znh5md-0.2.0-py3-none-any.whl", hash = "sha256:563e5c9b6a1e29864d31c4731416eb4b79f7acfee23cbe8cc53730b247d1125b"}, - {file = "znh5md-0.2.0.tar.gz", hash = "sha256:da60ca2e53f2d74cdc3a5954dbadae912973e4a29456329d2e6406f8d5e794bf"}, + {file = "znh5md-0.2.1-py3-none-any.whl", hash = "sha256:73a4393544ed791d6aaa3850a30e14dd4f04e820c3cf5ea5f0aa74eab1be52d0"}, + {file = "znh5md-0.2.1.tar.gz", hash = "sha256:aa10de0219f0abc0ab685c1e3d91b556ea9e03cca8e08c8d491e50dcd511cb0b"}, ] [package.dependencies] @@ -4903,4 +4885,4 @@ rdkit = ["rdkit2ase"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "28904682037b0f015c31ee09e1451c0db226f8a83e92dfe0a028323deba93ff1" +content-hash = "76d89fb129c3dfe3da2cc64d4f88c05e1c1722008287e4dafa40207f16aa06f6" diff --git a/pyproject.toml b/pyproject.toml index debbd46d8..5fcd83f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ mdanalysis = {version = "^2", optional = true} tidynamics = {version = "^1", optional = true} rdkit2ase = {version = "^0.1", optional = true} zntrack = {version = "^0.7.3", optional = true} -znframe = "^0.1" celery = "^5" sqlalchemy = "^2" redis = "^5" From c5ed6e90b05a5c6f78b712e5acf7b150a881003f Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 11:55:25 +0200 Subject: [PATCH 06/17] compute bonds after reading the file --- zndraw/bonds/__init__.py | 5 +++-- zndraw/cli.py | 8 ++++++-- zndraw/tasks/__init__.py | 20 ++++++++++++++++++++ zndraw/utils.py | 3 ++- 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/zndraw/bonds/__init__.py b/zndraw/bonds/__init__.py index e5aff3d85..6dfc5e9b2 100644 --- a/zndraw/bonds/__init__.py +++ b/zndraw/bonds/__init__.py @@ -47,8 +47,9 @@ def remove_edge(graph, atom_1, atom_2): except NetworkXError: pass - def get_bonds(self, atoms: ase.Atoms): - graph = atoms.connectivity + def get_bonds(self, atoms: ase.Atoms, graph: nx.Graph = None): + if graph is None: + graph = self.build_graph(atoms) bonds = [] for edge in graph.edges: bonds.append((edge[0], edge[1], graph.edges[edge]["weight"])) diff --git a/zndraw/cli.py b/zndraw/cli.py index 836f5f78c..4ad648bc5 100644 --- a/zndraw/cli.py +++ b/zndraw/cli.py @@ -6,11 +6,12 @@ import typing as t import typer +from celery import chain from zndraw.app import create_app from zndraw.base import FileIO from zndraw.standalone import run_celery_worker, run_znsocket -from zndraw.tasks import read_file +from zndraw.tasks import compute_bonds, read_file from zndraw.utils import get_port cli = typer.Typer() @@ -122,7 +123,10 @@ def main( app = create_app() - read_file.delay(fileio.to_dict()) + # read_file.delay(fileio.to_dict()) + # compute_bonds.delay() + + chain(read_file.s(fileio.to_dict()), compute_bonds.s()).apply_async() if browser: import webbrowser diff --git a/zndraw/tasks/__init__.py b/zndraw/tasks/__init__.py index 1c0a2cad9..4a5310185 100644 --- a/zndraw/tasks/__init__.py +++ b/zndraw/tasks/__init__.py @@ -10,6 +10,7 @@ from flask import current_app from zndraw.base import FileIO +from zndraw.bonds import ASEComputeBonds from zndraw.utils import ASEConverter log = logging.getLogger(__name__) @@ -103,6 +104,25 @@ def _generator(): io.disconnect() +@shared_task +def compute_bonds(room=None) -> None: + from zndraw.zndraw import ZnDrawLocal + + vis = ZnDrawLocal( + r=current_app.extensions["redis"], + url=current_app.config["SERVER_URL"], + token="default" if room is None else room, + ) + + bonds_calculator = ASEComputeBonds() + for idx, atoms in enumerate(vis): + try: + atoms.connectivity = bonds_calculator.get_bonds(atoms) + vis[idx] = atoms + except Exception as e: + vis.log(str(e)) + + @shared_task def run_modifier(room, data: dict) -> None: from zndraw.modify import Modifier diff --git a/zndraw/utils.py b/zndraw/utils.py index 6b98fa147..ea20af07d 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -137,7 +137,8 @@ def decode(self, value: ASEDict) -> ase.Atoms: cell=value["cell"], ) if connectivity := value.get("connectivity"): - atoms.connectivity = connectivity + # or do we want this to be nx.Graph? + atoms.connectivity = np.array(connectivity) if "colors" in value["arrays"]: atoms.arrays["colors"] = np.array(value["arrays"]["colors"]) if "radii" in value["arrays"]: From 0d9a67b171bdcbb55a69c02b17e03a296b84c048 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 12:04:35 +0200 Subject: [PATCH 07/17] connectivit is numpy --- tests/test_serializer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index ac3316d7f..aacb6b07c 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,3 +1,4 @@ +import numpy.testing as npt import pytest import znjson from ase.calculators.singlepoint import SinglePointCalculator @@ -20,7 +21,7 @@ def test_ase_converter(s22): for s1, s2 in zip(s22, structures): assert s1 == s2 - assert structures[0].connectivity == [[0, 1, 1], [1, 2, 1], [2, 3, 1]] + npt.assert_array_equal(structures[0].connectivity, [[0, 1, 1], [1, 2, 1], [2, 3, 1]]) with pytest.raises(AttributeError): _ = structures[1].connectivity From 2c7ca488ef8cbfddc4f1c0db6376ba20576bd7ee Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 16:30:22 +0200 Subject: [PATCH 08/17] test exotic atoms --- tests/test_serializer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index aacb6b07c..f975150fe 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -2,6 +2,7 @@ import pytest import znjson from ase.calculators.singlepoint import SinglePointCalculator +import ase from zndraw.utils import ASEConverter @@ -31,3 +32,13 @@ def test_ase_converter(s22): assert "radii" in structures[0].arrays assert structures[4].info == {"key": "value"} + + +def test_exotic_atoms(): + atoms = ase.Atoms("X", positions=[[0, 0, 0]]) + new_atoms = znjson.loads( + znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])), + cls=znjson.ZnDecoder.from_converters([ASEConverter]), + ) + npt.assert_array_equal(new_atoms.arrays["colors"], ['#ff0000']) + npt.assert_array_equal(new_atoms.arrays["radii"], [0.2]) From a8ac170e76c234701af655a1810cc2494b880f04 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 16:33:51 +0200 Subject: [PATCH 09/17] compute bonds directly --- zndraw/cli.py | 8 ++------ zndraw/tasks/__init__.py | 22 +++------------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/zndraw/cli.py b/zndraw/cli.py index 4ad648bc5..836f5f78c 100644 --- a/zndraw/cli.py +++ b/zndraw/cli.py @@ -6,12 +6,11 @@ import typing as t import typer -from celery import chain from zndraw.app import create_app from zndraw.base import FileIO from zndraw.standalone import run_celery_worker, run_znsocket -from zndraw.tasks import compute_bonds, read_file +from zndraw.tasks import read_file from zndraw.utils import get_port cli = typer.Typer() @@ -123,10 +122,7 @@ def main( app = create_app() - # read_file.delay(fileio.to_dict()) - # compute_bonds.delay() - - chain(read_file.s(fileio.to_dict()), compute_bonds.s()).apply_async() + read_file.delay(fileio.to_dict()) if browser: import webbrowser diff --git a/zndraw/tasks/__init__.py b/zndraw/tasks/__init__.py index 4a5310185..7fef23fdc 100644 --- a/zndraw/tasks/__init__.py +++ b/zndraw/tasks/__init__.py @@ -39,6 +39,7 @@ def read_file(fileio: dict) -> None: r.delete("room:default:frames") lst = znsocket.List(r, "room:default:frames") + bonds_calculator = ASEComputeBonds() if file_io.name is None: @@ -79,6 +80,8 @@ def _generator(): break if file_io.step and idx % file_io.step != 0: continue + if not hasattr(atoms, "connectivity"): + atoms.connectivity = bonds_calculator.get_bonds(atoms) lst.append( znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])) ) @@ -104,25 +107,6 @@ def _generator(): io.disconnect() -@shared_task -def compute_bonds(room=None) -> None: - from zndraw.zndraw import ZnDrawLocal - - vis = ZnDrawLocal( - r=current_app.extensions["redis"], - url=current_app.config["SERVER_URL"], - token="default" if room is None else room, - ) - - bonds_calculator = ASEComputeBonds() - for idx, atoms in enumerate(vis): - try: - atoms.connectivity = bonds_calculator.get_bonds(atoms) - vis[idx] = atoms - except Exception as e: - vis.log(str(e)) - - @shared_task def run_modifier(room, data: dict) -> None: from zndraw.modify import Modifier From eb803bc013eef769fa81e90395b16d453e47a888 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 16:37:34 +0200 Subject: [PATCH 10/17] compute bonds in vis object --- zndraw/zndraw.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index ae09959a8..5a1081326 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -15,6 +15,7 @@ from zndraw.base import Extension, ZnDrawBase from zndraw.draw import Geometry, Object3D from zndraw.utils import ASEConverter, call_with_retry, emit_with_retry +from zndraw.bonds import ASEComputeBonds log = logging.getLogger(__name__) @@ -86,6 +87,10 @@ class ZnDraw(ZnDrawBase): default_factory=datetime.datetime.now ) + bond_calculator: ASEComputeBonds = dataclasses.field( + default_factory=ASEComputeBonds, repr=False + ) + def __post_init__(self): def on_wakeup(): if self._available: @@ -168,16 +173,14 @@ def __setitem__( if isinstance(index, slice): index = list(range(*index.indices(len(self)))) if isinstance(index, int): - data = { - index: znjson.dumps( - value, cls=znjson.ZnEncoder.from_converters([ASEConverter]) - ) - } - else: - data = { - i: znjson.dumps(val, cls=znjson.ZnEncoder.from_converters([ASEConverter])) - for i, val in zip(index, value) - } + index = [index] + value = [value] + + data = {} + for i, val in zip(index, value): + if not hasattr(val, "connectivity"): + val.connectivity = self.bond_calculator.get_bonds(val) + data[i] = znjson.dumps(val, cls=znjson.ZnEncoder.from_converters([ASEConverter])) call_with_retry( self.socket, @@ -240,6 +243,9 @@ def extend(self, values: list[ase.Atoms]): ) for i, val in enumerate(tbar, start=len(self)): + if not hasattr(val, "connectivity"): + val.connectivity = self.bond_calculator.get_bonds(val) + msg[i] = znjson.dumps( val, cls=znjson.ZnEncoder.from_converters([ASEConverter]) ) From 1adb2f2d16e85708d6fad048f9ea030f59631365 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 16:38:27 +0200 Subject: [PATCH 11/17] allow to disable bond calculation --- zndraw/zndraw.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index 5a1081326..a0bf4bf27 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -87,7 +87,7 @@ class ZnDraw(ZnDrawBase): default_factory=datetime.datetime.now ) - bond_calculator: ASEComputeBonds = dataclasses.field( + bond_calculator: ASEComputeBonds|None = dataclasses.field( default_factory=ASEComputeBonds, repr=False ) @@ -178,7 +178,7 @@ def __setitem__( data = {} for i, val in zip(index, value): - if not hasattr(val, "connectivity"): + if not hasattr(val, "connectivity") and self.bond_calculator is not None: val.connectivity = self.bond_calculator.get_bonds(val) data[i] = znjson.dumps(val, cls=znjson.ZnEncoder.from_converters([ASEConverter])) @@ -243,7 +243,7 @@ def extend(self, values: list[ase.Atoms]): ) for i, val in enumerate(tbar, start=len(self)): - if not hasattr(val, "connectivity"): + if not hasattr(val, "connectivity") and self.bond_calculator is not None: val.connectivity = self.bond_calculator.get_bonds(val) msg[i] = znjson.dumps( From e69a33b0f49ba0c52081508f061a2ae08e3c1b65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Jun 2024 14:38:55 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_serializer.py | 4 ++-- zndraw/zndraw.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index f975150fe..93b07fb94 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,8 +1,8 @@ +import ase import numpy.testing as npt import pytest import znjson from ase.calculators.singlepoint import SinglePointCalculator -import ase from zndraw.utils import ASEConverter @@ -40,5 +40,5 @@ def test_exotic_atoms(): znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])), cls=znjson.ZnDecoder.from_converters([ASEConverter]), ) - npt.assert_array_equal(new_atoms.arrays["colors"], ['#ff0000']) + npt.assert_array_equal(new_atoms.arrays["colors"], ["#ff0000"]) npt.assert_array_equal(new_atoms.arrays["radii"], [0.2]) diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index a0bf4bf27..60d7ba261 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -13,9 +13,9 @@ from redis import Redis from zndraw.base import Extension, ZnDrawBase +from zndraw.bonds import ASEComputeBonds from zndraw.draw import Geometry, Object3D from zndraw.utils import ASEConverter, call_with_retry, emit_with_retry -from zndraw.bonds import ASEComputeBonds log = logging.getLogger(__name__) @@ -87,7 +87,7 @@ class ZnDraw(ZnDrawBase): default_factory=datetime.datetime.now ) - bond_calculator: ASEComputeBonds|None = dataclasses.field( + bond_calculator: ASEComputeBonds | None = dataclasses.field( default_factory=ASEComputeBonds, repr=False ) @@ -180,7 +180,9 @@ def __setitem__( for i, val in zip(index, value): if not hasattr(val, "connectivity") and self.bond_calculator is not None: val.connectivity = self.bond_calculator.get_bonds(val) - data[i] = znjson.dumps(val, cls=znjson.ZnEncoder.from_converters([ASEConverter])) + data[i] = znjson.dumps( + val, cls=znjson.ZnEncoder.from_converters([ASEConverter]) + ) call_with_retry( self.socket, From b8bb8d3d825e0d64f7230b4ea70f234c291cdc13 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 16:50:58 +0200 Subject: [PATCH 13/17] fix insert --- zndraw/zndraw.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index a0bf4bf27..d544ef210 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -212,6 +212,9 @@ def __delitem__(self, index: int | slice | list[int]): ) def insert(self, index: int, value: ase.Atoms): + if hasattr(value, "connectivity") and self.bond_calculator is not None: + value.connectivity = self.bond_calculator.get_bonds(value) + call_with_retry( self.socket, "room:frames:insert", From 81bb5f207d263dbde37a3dd833a940cc850cb13d Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 17:02:26 +0200 Subject: [PATCH 14/17] use scaled radii --- zndraw/utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/zndraw/utils.py b/zndraw/utils.py index ea20af07d..089b02e27 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -16,6 +16,7 @@ from ase.data import covalent_radii from ase.data.colors import jmol_colors from znjson import ConverterBase +import functools log = logging.getLogger(__name__) @@ -35,6 +36,15 @@ def rgb2hex(value): r, g, b = np.array(value * 255, dtype=int) return "#%02x%02x%02x" % (r, g, b) +@functools.lru_cache(maxsize=128) +def get_scaled_radii() -> np.ndarray: + """Scale down the covalent radii to visualize bonds better.""" + radii = covalent_radii + # shift the values such that they are in [0.3, 1.3] + radii = radii - np.min(radii) + radii = radii / np.max(radii) + radii = radii + 0.3 + return radii class ASEConverter(ConverterBase): """Encode/Decode datetime objects @@ -99,7 +109,8 @@ def encode(self, obj: ase.Atoms) -> ASEDict: ) if "radii" not in obj.arrays: - arrays["radii"] = [covalent_radii[number] for number in numbers] + # arrays["radii"] = [covalent_radii[number] for number in numbers] + arrays["radii"] = [get_scaled_radii()[number] for number in numbers] else: arrays["radii"] = ( obj.arrays["radii"].tolist() From 9888673919aaefd17e3bb67ae00137ddc35a7980 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Jun 2024 15:02:36 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- zndraw/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zndraw/utils.py b/zndraw/utils.py index 089b02e27..ecf680db3 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -1,3 +1,4 @@ +import functools import importlib.util import json import logging @@ -16,7 +17,6 @@ from ase.data import covalent_radii from ase.data.colors import jmol_colors from znjson import ConverterBase -import functools log = logging.getLogger(__name__) @@ -36,6 +36,7 @@ def rgb2hex(value): r, g, b = np.array(value * 255, dtype=int) return "#%02x%02x%02x" % (r, g, b) + @functools.lru_cache(maxsize=128) def get_scaled_radii() -> np.ndarray: """Scale down the covalent radii to visualize bonds better.""" @@ -46,6 +47,7 @@ def get_scaled_radii() -> np.ndarray: radii = radii + 0.3 return radii + class ASEConverter(ConverterBase): """Encode/Decode datetime objects From da108a3d4d439e6ea02a79f0f9f938cb1a721421 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 17:05:01 +0200 Subject: [PATCH 16/17] rescaled radii --- tests/test_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 93b07fb94..25681ca0c 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -41,4 +41,4 @@ def test_exotic_atoms(): cls=znjson.ZnDecoder.from_converters([ASEConverter]), ) npt.assert_array_equal(new_atoms.arrays["colors"], ["#ff0000"]) - npt.assert_array_equal(new_atoms.arrays["radii"], [0.2]) + npt.assert_array_equal(new_atoms.arrays["radii"], [0.3]) From 3dbee9a5e4dcbad41165a69924679f9caf5fac32 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Sat, 15 Jun 2024 17:08:55 +0200 Subject: [PATCH 17/17] remove log --- app/src/App.tsx | 3 --- 1 file changed, 3 deletions(-) diff --git a/app/src/App.tsx b/app/src/App.tsx index 0f3931250..bb394883b 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -134,9 +134,6 @@ export default function App() { // if step changes useEffect(() => { socket.emit("room:frames:get", [step], (frames: Frames) => { - console.log(frames); - // map positions: numbers[][] to THREE.Vector3[] - for (const key in frames) { if (frames.hasOwnProperty(key)) { const frame: Frame = frames[key]["value"];