Skip to content

Commit 3eff1d4

Browse files
committed
opendft: add jupyter demo
1 parent 68725cc commit 3eff1d4

File tree

2 files changed

+656
-0
lines changed

2 files changed

+656
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"/data2/haiyang/anaconda3/envs/QHBench/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13+
" from .autonotebook import tqdm as notebook_tqdm\n"
14+
]
15+
}
16+
],
17+
"source": [
18+
"import os\n",
19+
"import sys\n",
20+
"sys.path.insert(0, os.path.dirname(os.getcwd()))\n",
21+
"\n",
22+
"import torch\n",
23+
"from datasets import QH9Stable, QH9Dynamic"
24+
]
25+
},
26+
{
27+
"cell_type": "markdown",
28+
"metadata": {},
29+
"source": [
30+
"### Here are the statistics of the dataset"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": 2,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
def get_hamiltonian_size(molecule_atoms):
    """Return the orbital count implied by the entries of `molecule_atoms`.

    Every entry <= 2 contributes 2 and every entry > 2 contributes 14.
    `molecule_atoms` must support elementwise comparison and `.sum()`
    (presumably a tensor of atomic numbers -- confirm against the dataset).

    Returns:
        The result of the `.sum()` arithmetic (a 0-d tensor for tensor input).
    """
    # NOTE(review): the per-atom counts 2 and 14 are hard-coded here; confirm
    # they match the basis set the dataset was generated with.
    light_atom_count = (molecule_atoms <= 2).sum()
    heavy_atom_count = (molecule_atoms > 2).sum()
    return heavy_atom_count * 14 + light_atom_count * 2
44+
"\n",
45+
"\n",
def _summarize_values(prefix, values):
    """Build the four summary entries for one per-sample quantity.

    Args:
        prefix: Key prefix, e.g. ``'num_node'``.
        values: List of per-sample numbers (Python numbers or 0-d tensors).

    Returns:
        Dict mapping ``{prefix}_mean``, ``{prefix}_min``, ``{prefix}_max`` and
        ``{prefix}_median`` (inserted in that order) to Python floats. All four
        are NaN when ``values`` is empty.
    """
    if len(values) == 0:
        # Full min/max reductions raise on a tensor with zero elements; report
        # NaN instead so an empty split does not abort the whole computation.
        nan = float('nan')
        return {f'{prefix}_{stat}': nan for stat in ('mean', 'min', 'max', 'median')}

    as_tensor = torch.tensor(values).float()
    return {
        f'{prefix}_mean': as_tensor.mean().item(),
        f'{prefix}_min': as_tensor.min().item(),
        f'{prefix}_max': as_tensor.max().item(),
        f'{prefix}_median': as_tensor.median().item(),
    }


def get_dataset_statistic(ori_dataset):
    """Compute per-split summary statistics of a dataset.

    For each of the splits ``train``, ``val`` and ``test`` the dataset is
    indexed with its ``{split}_mask`` attribute, and mean / min / max / median
    are computed over three per-sample quantities:

    * ``num_node``: ``data.num_nodes``
    * ``num_electronics``: ``data.atoms.sum()``
    * ``hamiltonian_size``: ``get_hamiltonian_size(data.atoms)``

    Args:
        ori_dataset: Object exposing ``train_mask``, ``val_mask`` and
            ``test_mask`` attributes and supporting indexing by them; the
            selected subset must be iterable and yield items with
            ``num_nodes`` and ``atoms``.

    Returns:
        ``{split_name: {'<quantity>_<stat>': float, ...}}``. A split that
        selects no samples gets NaN for every statistic.
    """
    statistic_info = {}
    for split_name in ('train', 'val', 'test'):
        split_dataset = ori_dataset[getattr(ori_dataset, f'{split_name}_mask')]

        # Walk the split once and collect all three quantities together
        # instead of iterating the dataset separately for each of them.
        node_counts, electron_counts, hamiltonian_sizes = [], [], []
        for data in split_dataset:
            node_counts.append(data.num_nodes)
            electron_counts.append(data.atoms.sum())
            hamiltonian_sizes.append(get_hamiltonian_size(data.atoms))

        split_info = {}
        split_info.update(_summarize_values('num_node', node_counts))
        split_info.update(_summarize_values('num_electronics', electron_counts))
        split_info.update(_summarize_values('hamiltonian_size', hamiltonian_sizes))
        statistic_info[split_name] = split_info

    return statistic_info
81+
]
82+
},
83+
{
84+
"cell_type": "code",
85+
"execution_count": 4,
86+
"metadata": {},
87+
"outputs": [],
88+
"source": [
# Load QH9Stable with the 'random' split from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_stable_random = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='random',
)
dataset_stable_random_statistic = get_dataset_statistic(dataset_stable_random)
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": 3,
96+
"metadata": {},
97+
"outputs": [
98+
{
99+
"name": "stderr",
100+
"output_type": "stream",
101+
"text": [
102+
"Processing...\n",
103+
" 0%| | 113M/30.5G [00:19<22:44, 22.3MB/s]"
104+
]
105+
}
106+
],
107+
"source": [
# Load QH9Stable with the 'size_ood' split from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_stable_ood = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='size_ood',
)
dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"metadata": {},
116+
"outputs": [],
117+
"source": [
# Load QH9Dynamic (version '100k', 'geometry' split) from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_dynamic_geo_100k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='100k',
    split='geometry',
)
dataset_dynamic_geo_100k_statistic = get_dataset_statistic(dataset_dynamic_geo_100k)
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": null,
125+
"metadata": {},
126+
"outputs": [],
127+
"source": [
# Load QH9Dynamic (version '100k', 'mol' split) from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_dynamic_mol_100k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='100k',
    split='mol',
)
dataset_dynamic_mol_100k_statistic = get_dataset_statistic(dataset_dynamic_mol_100k)
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 3,
135+
"metadata": {},
136+
"outputs": [
137+
{
138+
"data": {
139+
"text/plain": [
140+
"{'train': {'num_node_mean': 18.03936004638672,\n",
141+
" 'num_node_min': 7.0,\n",
142+
" 'num_node_max': 27.0,\n",
143+
" 'num_node_median': 18.0,\n",
144+
" 'num_electronics_mean': 65.87992095947266,\n",
145+
" 'num_electronics_min': 24.0,\n",
146+
" 'num_electronics_max': 74.0,\n",
147+
" 'num_electronics_median': 66.0,\n",
148+
" 'hamiltonian_size_mean': 141.5890655517578,\n",
149+
" 'hamiltonian_size_min': 54.0,\n",
150+
" 'hamiltonian_size_max': 162.0,\n",
151+
" 'hamiltonian_size_median': 144.0},\n",
152+
" 'val': {'num_node_mean': 18.03936004638672,\n",
153+
" 'num_node_min': 7.0,\n",
154+
" 'num_node_max': 27.0,\n",
155+
" 'num_node_median': 18.0,\n",
156+
" 'num_electronics_mean': 65.87992095947266,\n",
157+
" 'num_electronics_min': 24.0,\n",
158+
" 'num_electronics_max': 74.0,\n",
159+
" 'num_electronics_median': 66.0,\n",
160+
" 'hamiltonian_size_mean': 141.5890655517578,\n",
161+
" 'hamiltonian_size_min': 54.0,\n",
162+
" 'hamiltonian_size_max': 162.0,\n",
163+
" 'hamiltonian_size_median': 144.0},\n",
164+
" 'test': {'num_node_mean': 18.03936004638672,\n",
165+
" 'num_node_min': 7.0,\n",
166+
" 'num_node_max': 27.0,\n",
167+
" 'num_node_median': 18.0,\n",
168+
" 'num_electronics_mean': 65.87992095947266,\n",
169+
" 'num_electronics_min': 24.0,\n",
170+
" 'num_electronics_max': 74.0,\n",
171+
" 'num_electronics_median': 66.0,\n",
172+
" 'hamiltonian_size_mean': 141.5890655517578,\n",
173+
" 'hamiltonian_size_min': 54.0,\n",
174+
" 'hamiltonian_size_max': 162.0,\n",
175+
" 'hamiltonian_size_median': 144.0}}"
176+
]
177+
},
178+
"execution_count": 3,
179+
"metadata": {},
180+
"output_type": "execute_result"
181+
}
182+
],
183+
"source": [
# Load QH9Dynamic (version '300k', 'geometry' split) from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_dynamic_geo_300k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='300k',
    split='geometry',
)
dataset_dynamic_geo_300k_statistic = get_dataset_statistic(dataset_dynamic_geo_300k)
dataset_dynamic_geo_300k_statistic
187+
]
188+
},
189+
{
190+
"cell_type": "code",
191+
"execution_count": 3,
192+
"metadata": {},
193+
"outputs": [
194+
{
195+
"data": {
196+
"text/plain": [
197+
"{'train': {'num_node_mean': 18.015846252441406,\n",
198+
" 'num_node_min': 7.0,\n",
199+
" 'num_node_max': 27.0,\n",
200+
" 'num_node_median': 18.0,\n",
201+
" 'num_electronics_mean': 65.91242980957031,\n",
202+
" 'num_electronics_min': 24.0,\n",
203+
" 'num_electronics_max': 74.0,\n",
204+
" 'num_electronics_median': 66.0,\n",
205+
" 'hamiltonian_size_mean': 141.58465576171875,\n",
206+
" 'hamiltonian_size_min': 54.0,\n",
207+
" 'hamiltonian_size_max': 162.0,\n",
208+
" 'hamiltonian_size_median': 144.0},\n",
209+
" 'val': {'num_node_mean': 18.153846740722656,\n",
210+
" 'num_node_min': 10.0,\n",
211+
" 'num_node_max': 25.0,\n",
212+
" 'num_node_median': 18.0,\n",
213+
" 'num_electronics_mean': 65.71237182617188,\n",
214+
" 'num_electronics_min': 34.0,\n",
215+
" 'num_electronics_max': 74.0,\n",
216+
" 'num_electronics_median': 66.0,\n",
217+
" 'hamiltonian_size_mean': 141.17726135253906,\n",
218+
" 'hamiltonian_size_min': 72.0,\n",
219+
" 'hamiltonian_size_max': 158.0,\n",
220+
" 'hamiltonian_size_median': 144.0},\n",
221+
" 'test': {'num_node_mean': 18.112957000732422,\n",
222+
" 'num_node_min': 9.0,\n",
223+
" 'num_node_max': 25.0,\n",
224+
" 'num_node_median': 18.0,\n",
225+
" 'num_electronics_mean': 65.7873764038086,\n",
226+
" 'num_electronics_min': 50.0,\n",
227+
" 'num_electronics_max': 72.0,\n",
228+
" 'num_electronics_median': 66.0,\n",
229+
" 'hamiltonian_size_mean': 142.03321838378906,\n",
230+
" 'hamiltonian_size_min': 102.0,\n",
231+
" 'hamiltonian_size_max': 158.0,\n",
232+
" 'hamiltonian_size_median': 144.0}}"
233+
]
234+
},
235+
"execution_count": 3,
236+
"metadata": {},
237+
"output_type": "execute_result"
238+
}
239+
],
240+
"source": [
# Load QH9Dynamic (version '300k', 'mol' split) from `<parent of cwd>/datasets`.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_dynamic_mol_300k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='300k',
    split='mol',
)
dataset_dynamic_mol_300k_statistic = get_dataset_statistic(dataset_dynamic_mol_300k)
dataset_dynamic_mol_300k_statistic
244+
]
245+
},
246+
{
247+
"cell_type": "code",
248+
"execution_count": 9,
249+
"metadata": {},
250+
"outputs": [
251+
{
252+
"data": {
253+
"text/plain": [
254+
"{'train': {'num_node_mean': 18.023332595825195,\n",
255+
" 'num_node_min': 3.0,\n",
256+
" 'num_node_max': 29.0,\n",
257+
" 'num_node_median': 18.0,\n",
258+
" 'num_electronics_mean': 65.89612579345703,\n",
259+
" 'num_electronics_min': 10.0,\n",
260+
" 'num_electronics_max': 74.0,\n",
261+
" 'num_electronics_median': 66.0,\n",
262+
" 'hamiltonian_size_mean': 141.5970001220703,\n",
263+
" 'hamiltonian_size_min': 18.0,\n",
264+
" 'hamiltonian_size_max': 166.0,\n",
265+
" 'hamiltonian_size_median': 144.0},\n",
266+
" 'val': {'num_node_mean': 18.026752471923828,\n",
267+
" 'num_node_min': 6.0,\n",
268+
" 'num_node_max': 29.0,\n",
269+
" 'num_node_median': 18.0,\n",
270+
" 'num_electronics_mean': 65.90185546875,\n",
271+
" 'num_electronics_min': 18.0,\n",
272+
" 'num_electronics_max': 74.0,\n",
273+
" 'num_electronics_median': 66.0,\n",
274+
" 'hamiltonian_size_mean': 141.6219482421875,\n",
275+
" 'hamiltonian_size_min': 36.0,\n",
276+
" 'hamiltonian_size_max': 166.0,\n",
277+
" 'hamiltonian_size_median': 144.0},\n",
278+
" 'test': {'num_node_mean': 18.035158157348633,\n",
279+
" 'num_node_min': 4.0,\n",
280+
" 'num_node_max': 29.0,\n",
281+
" 'num_node_median': 18.0,\n",
282+
" 'num_electronics_mean': 65.8647232055664,\n",
283+
" 'num_electronics_min': 24.0,\n",
284+
" 'num_electronics_max': 74.0,\n",
285+
" 'num_electronics_median': 66.0,\n",
286+
" 'hamiltonian_size_mean': 141.55824279785156,\n",
287+
" 'hamiltonian_size_min': 48.0,\n",
288+
" 'hamiltonian_size_max': 166.0,\n",
289+
" 'hamiltonian_size_median': 144.0}}"
290+
]
291+
},
292+
"execution_count": 9,
293+
"metadata": {},
294+
"output_type": "execute_result"
295+
}
296+
],
297+
"source": [
298+
"dataset_stable_random_statistic"
299+
]
300+
},
301+
{
302+
"cell_type": "code",
303+
"execution_count": 3,
304+
"metadata": {},
305+
"outputs": [],
306+
"source": [
# Load QH9Stable with the 'size_ood' split from `<parent of cwd>/datasets`
# and display its statistics.
# os.path.dirname matches how the import cell locates the parent directory and
# stays an absolute path even when cwd sits directly under the filesystem root.
dataset_stable_ood = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='size_ood',
)
dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)
dataset_stable_ood_statistic
310+
]
311+
},
312+
{
313+
"cell_type": "code",
314+
"execution_count": 16,
315+
"metadata": {},
316+
"outputs": [
317+
{
318+
"data": {
319+
"text/plain": [
320+
"array([], dtype=int64)"
321+
]
322+
},
323+
"execution_count": 16,
324+
"metadata": {},
325+
"output_type": "execute_result"
326+
}
327+
],
328+
"source": [
329+
"dataset_stable_ood.train_mask"
330+
]
331+
},
332+
{
333+
"cell_type": "code",
334+
"execution_count": null,
335+
"metadata": {},
336+
"outputs": [],
337+
"source": []
338+
}
339+
],
340+
"metadata": {
341+
"kernelspec": {
342+
"display_name": "QHBench",
343+
"language": "python",
344+
"name": "python3"
345+
},
346+
"language_info": {
347+
"codemirror_mode": {
348+
"name": "ipython",
349+
"version": 3
350+
},
351+
"file_extension": ".py",
352+
"mimetype": "text/x-python",
353+
"name": "python",
354+
"nbconvert_exporter": "python",
355+
"pygments_lexer": "ipython3",
356+
"version": "3.8.18"
357+
}
358+
},
359+
"nbformat": 4,
360+
"nbformat_minor": 2
361+
}

0 commit comments

Comments
 (0)