-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path__init__.py
243 lines (194 loc) · 6.59 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import logging
import typing as t
import ase.io
import socketio.exceptions
import tqdm
import znjson
import znsocket
from celery import shared_task
from flask import current_app
from zndraw.base import FileIO
from zndraw.bonds import ASEComputeBonds
from zndraw.utils import ASEConverter
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
@shared_task
def run_znsocket_server(port: int) -> None:
    """Run a ZnSocket server on *port*.

    A critical log message is written once ``Server.run()`` returns.
    """
    # Does not work with eventlet enabled!
    server = znsocket.Server(port=port)
    server.run()
    log.critical("ZnSocket server closed.")
def _open_frame_source(file_io: FileIO) -> t.Iterable[ase.Atoms]:
    """Return an iterable of ``ase.Atoms`` for the source described by *file_io*.

    Resolution order:

    * ``file_io.name is None``: a single empty ``ase.Atoms``.
    * ``file_io.remote`` is set: ``file_io.name`` is ``"<node>.<attribute>"``
      and the attribute is read from the ZnTrack node loaded from
      ``file_io.remote`` at revision ``file_io.rev``.
    * name ends in ``.h5`` / ``.hdf5`` / ``.h5md``: read via ZnH5MD.
    * anything else: ``ase.io.iread(file_io.name)``.

    Raises:
        ImportError: if ZnTrack or ZnH5MD is required but not installed.
    """
    if file_io.name is None:
        return [ase.Atoms()]

    if file_io.remote is not None:
        node_name, attribute = file_io.name.split(".", 1)
        # Only the import itself is guarded, so ImportErrors raised inside
        # ZnTrack are not misreported as "not installed".
        try:
            import zntrack
        except ImportError as err:
            raise ImportError(
                "You need to install ZnTrack to use the remote feature"
                " (or `pip install zndraw[all]`)."
            ) from err
        node = zntrack.from_rev(node_name, remote=file_io.remote, rev=file_io.rev)
        return getattr(node, attribute)

    if file_io.name.endswith((".h5", ".hdf5", ".h5md")):
        try:
            import znh5md
        except ImportError as err:
            raise ImportError(
                "You need to install ZnH5MD to read H5MD files"
                " (or `pip install zndraw[all]`)."
            ) from err
        return znh5md.ASEH5MD(file_io.name).get_atoms_list()

    return ase.io.iread(file_io.name)


def _emit_frames_refresh(io: socketio.Client, url: str, index: int) -> None:
    """Emit ``room:all:frames:refresh`` for *index*, connecting to *url* if needed.

    Connection errors are retried once per second without limit.
    """
    while True:
        try:
            if not io.connected:
                io.connect(url, wait_timeout=10)
            io.emit("room:all:frames:refresh", [index])
            return
        except socketio.exceptions.ConnectionError:
            pass
        io.sleep(1)


@shared_task
def read_file(fileio: dict) -> None:
    """Load the frames described by *fileio* into ``room:default:frames``.

    *fileio* holds the keyword arguments for ``FileIO``. The Redis-backed list
    ``room:default:frames`` is cleared and refilled with the znjson-encoded
    frames selected by ``start`` / ``stop`` / ``step``. These are applied to
    the frame index like a slice ``[start:stop:step]``; a falsy value means
    "not set", and a negative ``start`` is treated as 0. Frames without a
    ``connectivity`` attribute get bonds computed before they are stored.

    Clients are told to refresh via ``room:all:frames:refresh``:
    best-effort right after the first frame is stored, and again (retrying
    until the server is reachable) once loading has finished.
    """
    file_io = FileIO(**fileio)
    r = current_app.extensions["redis"]
    url = current_app.config["SERVER_URL"]
    io = socketio.Client()

    # TODO: make everyone join room main
    # TODO: chain this with compute_bonds so loading is faster.
    r.delete("room:default:frames")
    lst = znsocket.List(r, "room:default:frames")
    bonds_calculator = ASEComputeBonds()
    # Loop-invariant: build the encoder class once instead of per frame.
    encoder_cls = znjson.ZnEncoder.from_converters([ASEConverter])

    start = max(file_io.start or 0, 0)
    n_stored = 0
    try:
        for idx, atoms in tqdm.tqdm(enumerate(_open_frame_source(file_io))):
            if idx < start:
                continue
            if file_io.stop and idx >= file_io.stop:
                break
            # Count steps from ``start`` (not from 0) so start=1, step=2
            # keeps frames 1, 3, 5, ... like ``frames[1::2]``.
            if file_io.step and (idx - start) % file_io.step != 0:
                continue
            if not hasattr(atoms, "connectivity"):
                atoms.connectivity = bonds_calculator.get_bonds(atoms)
            lst.append(znjson.dumps(atoms, cls=encoder_cls))
            n_stored += 1

            if n_stored == 1:
                # Show the first frame early. A connection failure here is
                # tolerated because the final refresh below retries.
                try:
                    io.connect(url, wait_timeout=10)
                    io.emit("room:all:frames:refresh", [0])
                except socketio.exceptions.ConnectionError:
                    pass

        # Use the list position of the last stored frame: with start/step/stop
        # the file index no longer matches a position in ``lst``.
        _emit_frames_refresh(io, url, max(n_stored - 1, 0))
    finally:
        # Release the socket even when reading or encoding a frame fails.
        if io.connected:
            io.disconnect()
@shared_task
def run_modifier(room, data: dict) -> None:
    """Build a ``Modifier`` from *data* and run it against *room*.

    ``room:modifier:queue`` is emitted with ``0`` before the run and ``-1``
    afterwards, including on failure. Exceptions raised while building or
    running the modifier are reported to the room via ``vis.log`` and
    written, with traceback, to this module's logger. They are not re-raised.
    """
    from zndraw.modify import Modifier
    from zndraw.zndraw import ZnDrawLocal

    vis = ZnDrawLocal(
        r=current_app.extensions["redis"],
        url=current_app.config["SERVER_URL"],
        token=room,
    )
    vis.socket.emit("room:modifier:queue", 0)
    try:
        modifier = Modifier(**data)
        modifier.run(vis)
    except Exception as e:
        # The failure is deliberately not propagated to Celery. Record the
        # traceback server-side, because the room message only carries str(e).
        log.exception("Modifier failed in room %s", room)
        vis.log(str(e))
    finally:
        vis.socket.emit("room:modifier:queue", -1)
        # wait briefly before disconnecting (presumably so the emit is delivered)
        vis.socket.sleep(1)
        vis.socket.disconnect()
@shared_task
def run_selection(room, data: dict) -> None:
    """Build a ``Selection`` from *data* and run it against *room*.

    ``room:selection:queue`` is emitted with ``0`` before the run and ``-1``
    afterwards, also when the selection raises. Exceptions are not caught
    here and propagate to the caller.
    """
    from zndraw.selection import Selection
    from zndraw.zndraw import ZnDrawLocal

    redis_client = current_app.extensions["redis"]
    server_url = current_app.config["SERVER_URL"]
    vis = ZnDrawLocal(r=redis_client, url=server_url, token=room)

    vis.socket.emit("room:selection:queue", 0)
    try:
        Selection(**data).run(vis)
    finally:
        vis.socket.emit("room:selection:queue", -1)
        # Pause before disconnecting (presumably so the emit above is delivered).
        vis.socket.sleep(1)
        vis.socket.disconnect()
@shared_task
def run_analysis(room, data: dict) -> None:
    """Build an ``Analysis`` from *data* and run it against *room*.

    ``room:analysis:queue`` is emitted with ``0`` before the run and ``-1``
    afterwards, including on failure. Exceptions raised while building or
    running the analysis are reported to the room via ``vis.log`` and
    written, with traceback, to this module's logger. They are not re-raised.
    """
    from zndraw.analyse import Analysis
    from zndraw.zndraw import ZnDrawLocal

    vis = ZnDrawLocal(
        r=current_app.extensions["redis"],
        url=current_app.config["SERVER_URL"],
        token=room,
    )
    vis.socket.emit("room:analysis:queue", 0)
    try:
        analysis = Analysis(**data)
        analysis.run(vis)
    except Exception as e:
        # The failure is deliberately not propagated to Celery. Record the
        # traceback server-side, because the room message only carries str(e).
        log.exception("Analysis failed in room %s", room)
        vis.log(str(e))
    finally:
        vis.socket.emit("room:analysis:queue", -1)
        # wait briefly before disconnecting (presumably so the emit is delivered)
        vis.socket.sleep(1)
        vis.socket.disconnect()
@shared_task
def run_geometry(room, data: dict) -> None:
    """Build a ``Geometry`` from *data* and run it against *room*.

    ``room:geometry:queue`` is emitted with ``0`` before the run and ``-1``
    afterwards, also when it raises. Exceptions are not caught here and
    propagate to the caller.
    """
    from zndraw.draw import Geometry
    from zndraw.zndraw import ZnDrawLocal

    redis_client = current_app.extensions["redis"]
    server_url = current_app.config["SERVER_URL"]
    vis = ZnDrawLocal(r=redis_client, url=server_url, token=room)

    vis.socket.emit("room:geometry:queue", 0)
    try:
        # TODO: set the position / rotation / scale
        Geometry(**data).run(vis)
    finally:
        vis.socket.emit("room:geometry:queue", -1)
        # Pause before disconnecting (presumably so the emit above is delivered).
        vis.socket.sleep(1)
        vis.socket.disconnect()
@shared_task
def run_upload_file(room, data: dict):
    """Parse an uploaded structure file and add its frames to *room*.

    *data* must provide ``"filename"`` and ``"content"``. The content is
    passed to ``bytes(...)`` and decoded as UTF-8; it presumably arrives as a
    list of byte values (confirm against the caller). The file extension,
    case-insensitively, selects the ASE format, with ``xyz`` read as
    ``extxyz``. H5MD files (``h5`` / ``hdf5`` / ``h5md``) are rejected with
    ``ValueError``.

    If the file holds exactly one structure and the room has ``points``, a
    copy of that structure is placed at every point (matching its centre of
    mass to the point) and the combined scene is appended as one new frame.
    Otherwise all parsed structures are appended as frames. Afterwards the
    room's step is moved to the last frame. The socket is disconnected in
    every case, including on error.

    Raises:
        ValueError: for H5MD uploads.
    """
    from io import StringIO

    import ase.io

    from zndraw.zndraw import ZnDrawLocal

    vis = ZnDrawLocal(
        r=current_app.extensions["redis"],
        url=current_app.config["SERVER_URL"],
        token=room,
    )
    try:
        vis.log(f"Uploading {data['filename']}")
        # ``file_format`` rather than ``format`` to avoid shadowing the builtin.
        file_format = data["filename"].rsplit(".", 1)[-1].lower()
        if file_format == "xyz":
            file_format = "extxyz"
        # Same extension set that read_file treats as H5MD.
        if file_format in ("h5", "hdf5", "h5md"):
            raise ValueError("H5MD format not supported for uploading yet")

        stream = StringIO(bytes(data["content"]).decode("utf-8"))
        atoms_list = list(ase.io.iread(stream, format=file_format))

        # ``len(...) != 0`` rather than truthiness, in case ``points`` is an
        # array type (NOTE(review): container type not visible here).
        if len(atoms_list) == 1 and len(vis.points) != 0:
            scene = vis.atoms
            atoms = atoms_list[0]
            # Drop the stored bonds of the combined scene (presumably so they
            # are recomputed for the new atoms — confirm).
            if hasattr(scene, "connectivity"):
                del scene.connectivity
            for point in vis.points:
                # Shift so the centre of mass sits on ``point``, then add a copy.
                atoms.positions -= atoms.get_center_of_mass() - point
                scene += atoms
            vis.append(scene)
        else:
            vis.extend(atoms_list)
        vis.step = len(vis) - 1
    finally:
        # Always release the socket, also when parsing or validation fails.
        vis.socket.sleep(1)
        vis.socket.disconnect()