Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use znjson instead of znframe #481

Merged
merged 19 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions app/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,15 @@ export default function App() {
// if step changes
useEffect(() => {
socket.emit("room:frames:get", [step], (frames: Frames) => {
// map positions: numbers[][] to THREE.Vector3[]

for (const key in frames) {
if (frames.hasOwnProperty(key)) {
const frame = frames[key];
const frame: Frame = frames[key]["value"];
frame.positions = frame.positions.map(
(position) =>
new THREE.Vector3(position[0], position[1], position[2]),
) as THREE.Vector3[];
}
setCurrentFrame(frames[step]);
setCurrentFrame(frames[step]["value"]);
setNeedsUpdate(false); // rename this to something more descriptive
}
});
Expand Down
2 changes: 1 addition & 1 deletion app/src/components/particles.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export interface Frame {
``;

export interface Frames {
[key: number]: Frame;
[key: number]: { _type: string; value: Frame };
}

export const Player = ({
Expand Down
28 changes: 5 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ mdanalysis = {version = "^2", optional = true}
tidynamics = {version = "^1", optional = true}
rdkit2ase = {version = "^0.1", optional = true}
zntrack = {version = "^0.7.3", optional = true}
znframe = "^0.1"
celery = "^5"
sqlalchemy = "^2"
redis = "^5"
splines = "^0.3"
znsocket = "^0.1.6"
lazy-loader = "^0.4"
znjson = "^0.2.3"


[tool.poetry.group.dev.dependencies]
Expand Down
44 changes: 44 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import ase
import numpy.testing as npt
import pytest
import znjson
from ase.calculators.singlepoint import SinglePointCalculator

from zndraw.utils import ASEConverter


def test_ase_converter(s22):
s22[0].connectivity = [[0, 1, 1], [1, 2, 1], [2, 3, 1]]
s22[3].calc = SinglePointCalculator(s22[3])
s22[3].calc.results = {"energy": 0.0, "predicted_energy": 1.0}
s22[4].info = {"key": "value"}

structures_json = znjson.dumps(
s22, cls=znjson.ZnEncoder.from_converters([ASEConverter])
)
structures = znjson.loads(
structures_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])
)
for s1, s2 in zip(s22, structures):
assert s1 == s2

npt.assert_array_equal(structures[0].connectivity, [[0, 1, 1], [1, 2, 1], [2, 3, 1]])
with pytest.raises(AttributeError):
_ = structures[1].connectivity

assert structures[3].calc.results == {"energy": 0.0, "predicted_energy": 1.0}

assert "colors" in structures[0].arrays
assert "radii" in structures[0].arrays

assert structures[4].info == {"key": "value"}


def test_exotic_atoms():
atoms = ase.Atoms("X", positions=[[0, 0, 0]])
new_atoms = znjson.loads(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ff0000"])
npt.assert_array_equal(new_atoms.arrays["radii"], [0.3])
8 changes: 6 additions & 2 deletions zndraw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import ase
import numpy as np
import splines
import znframe
import znjson
import znsocket
from flask import current_app, session
from pydantic import BaseModel, Field, create_model
from redis import Redis

from zndraw.utils import ASEConverter

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -42,7 +44,9 @@ def get_atoms() -> ase.Atoms:
lst = znsocket.List(r, key)
try:
frame_json = lst[int(step)]
return znframe.Frame.from_json(frame_json).to_atoms()
return znjson.loads(
frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])
)
except TypeError:
# step is None
return ase.Atoms()
Expand Down
5 changes: 3 additions & 2 deletions zndraw/bonds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def remove_edge(graph, atom_1, atom_2):
except NetworkXError:
pass

def get_bonds(self, atoms: ase.Atoms):
graph = atoms.connectivity
def get_bonds(self, atoms: ase.Atoms, graph: nx.Graph = None):
if graph is None:
graph = self.build_graph(atoms)
bonds = []
for edge in graph.edges:
bonds.append((edge[0], edge[1], graph.edges[edge]["weight"]))
Expand Down
8 changes: 4 additions & 4 deletions zndraw/modify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
from ase.data import chemical_symbols
from pydantic import Field
from znframe.frame import get_radius

from zndraw.base import Extension, MethodsCollection

Expand Down Expand Up @@ -50,13 +49,13 @@ class Connect(UpdateScene):
"""Create guiding curve between selected atoms."""

def run(self, vis: "ZnDraw", **kwargs) -> None:
atoms = vis.atoms
atom_ids = vis.selection
atom_positions = vis.atoms.get_positions()
atom_numbers = vis.atoms.numbers[atom_ids]
camera_position = np.array(vis.camera["position"])[None, :] # 1,3

new_points = atom_positions[atom_ids] # N, 3
radii: np.ndarray = get_radius(atom_numbers)[0][:, None] # N, 1
radii: np.ndarray = atoms.arrays["radii"][atom_ids][:, None]
direction = camera_position - new_points
direction /= np.linalg.norm(direction, axis=1, keepdims=True)
new_points += direction * radii
Expand Down Expand Up @@ -118,7 +117,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
else:
for idx, atom_id in enumerate(sorted(atom_ids)):
atoms.pop(atom_id - idx) # we remove the atom and shift the index
del atoms.connectivity
if hasattr(atoms, "connectivity"):
del atoms.connectivity
vis.append(atoms)
vis.selection = []
vis.step += 1
Expand Down
8 changes: 6 additions & 2 deletions zndraw/scene.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import enum

import ase
import znframe
import znjson
import znsocket
from flask import current_app, session
from pydantic import BaseModel, Field
from redis import Redis

from zndraw.utils import ASEConverter


class Material(str, enum.Enum):
MeshBasicMaterial = "MeshBasicMaterial"
Expand Down Expand Up @@ -74,7 +76,9 @@ def _get_atoms() -> ase.Atoms:
lst = znsocket.List(r, key)
try:
frame_json = lst[int(step)]
return znframe.Frame.from_json(frame_json).to_atoms()
return znjson.loads(
frame_json, cls=znjson.ZnDecoder.from_converters([ASEConverter])
)
except TypeError:
# step is None
return ase.Atoms()
Expand Down
12 changes: 9 additions & 3 deletions zndraw/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import ase.io
import socketio.exceptions
import tqdm
import znframe
import znjson
import znsocket
from celery import shared_task
from flask import current_app

from zndraw.base import FileIO
from zndraw.bonds import ASEComputeBonds
from zndraw.utils import ASEConverter

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,6 +39,7 @@ def read_file(fileio: dict) -> None:
r.delete("room:default:frames")

lst = znsocket.List(r, "room:default:frames")
bonds_calculator = ASEComputeBonds()

if file_io.name is None:

Expand Down Expand Up @@ -77,8 +80,11 @@ def _generator():
break
if file_io.step and idx % file_io.step != 0:
continue
frame = znframe.Frame.from_atoms(atoms)
lst.append(frame.to_json())
if not hasattr(atoms, "connectivity"):
atoms.connectivity = bonds_calculator.get_bonds(atoms)
lst.append(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter]))
)
if idx == 0:
try:
io.connect(current_app.config["SERVER_URL"], wait_timeout=10)
Expand Down
Loading
Loading