Commit

PhilipDeegan committed Sep 12, 2024
1 parent 04238f6 commit 9049ef5
Showing 7 changed files with 85 additions and 41 deletions.
3 changes: 1 addition & 2 deletions pyphare/pyphare/pharesee/hierarchy/fromh5.py
@@ -24,9 +24,8 @@
 particle_files_patterns = ("domain", "patchGhost", "levelGhost")


-def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"]):
+def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"], hier=None):
     time = format_timestamp(time)
-    hier = None
     path = Path(filepath)
     for h5 in path.glob("*.h5"):
         if h5.parent == path and h5.stem not in exclude:
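
Note: a minimal usage sketch of the new `hier` parameter (the output directories below are hypothetical); the reading is that callers can now pass an existing hierarchy to accumulate into, instead of the function always starting from the hard-coded `hier = None`:

    from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5

    # first call builds a hierarchy from every *.h5 file found in the directory
    hier = get_all_available_quantities_from_h5("phare_outputs/run_a", time=0)
    # second call presumably merges the quantities of another directory into the same hierarchy
    hier = get_all_available_quantities_from_h5("phare_outputs/run_b", time=0, hier=hier)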
38 changes: 25 additions & 13 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py
@@ -2,6 +2,8 @@
 from copy import deepcopy
 import numpy as np

+from typing import Any
+
 from .hierarchy import PatchHierarchy, format_timestamp

[Code scanning / CodeQL notice: import of module pyphare.pharesee.hierarchy.hierarchy begins an import cycle.]

 from .patchdata import FieldData, ParticleData
 from .patchlevel import PatchLevel
@@ -10,7 +12,6 @@
 from ...core.gridlayout import GridLayout
 from ...core.phare_utilities import listify
 from ...core.phare_utilities import refinement_ratio
-from pyphare.pharesee import particles as mparticles


 field_qties = {
@@ -562,15 +563,24 @@ def _compute_scalardiv(patch_datas, **kwargs):
 class EqualityReport:
     ok: bool
     reason: str
+    ref: Any = None
+    cmp: Any = None

     def __bool__(self):
         return self.ok

     def __repr__(self):
         return self.reason

+    def __post_init__(self):
+        not_nones = [a is not None for a in [self.ref, self.cmp]]
+        if all(not_nones):
+            assert id(self.ref) != id(self.cmp)
+        else:
+            assert not any(not_nones)
+

-def hierarchy_compare(this, that):
+def hierarchy_compare(this, that, atol=1e-16):
     if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
         return EqualityReport(False, "class type mismatch")

@@ -596,24 +606,26 @@ def hierarchy_compare(this, that):
                 patch_cmp = patch_level_cmp.patches[patch_idx]

                 if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
-                    print(list(patch_ref.patch_datas.keys()))
-                    print(list(patch_cmp.patch_datas.keys()))
                     return EqualityReport(False, "data keys mismatch")

                 for patch_data_key in patch_ref.patch_datas.keys():
                     patch_data_ref = patch_ref.patch_datas[patch_data_key]
                     patch_data_cmp = patch_cmp.patch_datas[patch_data_key]

-                    if patch_data_cmp != patch_data_ref:
-                        msg = f"data mismatch: {patch_data_key} {type(patch_data_cmp).__name__} {type(patch_data_ref).__name__}"
-                        return EqualityReport(False, msg)
+                    if not patch_data_cmp.compare(patch_data_ref, atol=atol):
+                        msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
+                        return EqualityReport(
+                            False, msg, patch_data_cmp, patch_data_ref
+                        )

     return EqualityReport(True, "OK")


-def single_patch_for_LO(hier, qties=None):
+def single_patch_for_LO(hier, qties=None, skip=None):
     def _skip(qty):
-        return qties is not None and qty not in qties
+        return (qties is not None and qty not in qties) or (
+            skip is not None and qty in skip
+        )

     cier = deepcopy(hier)
     sim = hier.sim
@@ -633,22 +645,22 @@ def _skip(qty):
                     layout, v.field_name, None, centering=v.centerings
                 )
                 l0_pds[k].dataset = np.zeros(l0_pds[k].size)
+                patch_box = hier.level(0, t).patches[0].box
+                l0_pds[k][patch_box] = v[patch_box]

             elif isinstance(v, ParticleData):
                 l0_pds[k] = deepcopy(v)
             else:
                 raise RuntimeError("unexpected state")

-        for patch in hier.level(0, t).patches:
+        for patch in hier.level(0, t).patches[1:]:
             for k, v in patch.patch_datas.items():
                 if _skip(k):
                     continue
                 if isinstance(v, FieldData):
                     l0_pds[k][patch.box] = v[patch.box]
                 elif isinstance(v, ParticleData):
-                    l0_pds[k].dataset = mparticles.aggregate(
-                        [l0_pds[k].dataset, v.dataset]
-                    )
+                    l0_pds[k].dataset.add(v.dataset)
                 else:
                     raise RuntimeError("unexpected state")
     return cier
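
Note: a short sketch of how the extended comparison API reads from the caller's side (`h1` and `h2` are placeholder hierarchies, the tolerance is illustrative):

    report = hierarchy_compare(h1, h2, atol=1e-14)
    if not report:  # EqualityReport.__bool__ returns report.ok
        print(report)  # __repr__ prints the mismatch reason
        # on a data mismatch the report now also carries the two offending patch datas,
        # so callers can inspect them directly instead of re-walking the hierarchies
        bad_a, bad_b = report.ref, report.cmp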
18 changes: 13 additions & 5 deletions pyphare/pyphare/pharesee/hierarchy/patchdata.py
@@ -1,6 +1,8 @@
 import numpy as np

-from ...core.phare_utilities import deep_copy, fp_any_all_close
+from ...core import phare_utilities as phut
+
+# deep_copy, fp_any_all_close, assert_fp_any_all_close
 from ...core import box as boxm
 from ...core.box import Box

@@ -24,7 +26,7 @@ def __init__(self, layout, quantity):

     def __deepcopy__(self, memo):
         no_copy_keys = ["dataset"]  # do not copy these things
-        return deep_copy(self, memo, no_copy_keys)
+        return phut.deep_copy(self, memo, no_copy_keys)


 class FieldData(PatchData):
@@ -81,10 +83,12 @@ def __repr__(self):
         return self.__str__()

     def compare(self, that, atol=1e-16):
-        return fp_any_all_close(self.dataset[:], that.dataset[:], atol)
+        return self.field_name == that.field_name and phut.fp_any_all_close(
+            self.dataset[:], that.dataset[:], atol=atol
+        )

     def __eq__(self, that):
-        return self.field_name == that.field_name and self.compare(that)
+        return self.compare(that)

     def __ne__(self, that):
         return not (self == that)
@@ -228,5 +232,9 @@ def __getitem__(self, box):
     def size(self):
         return self.dataset.size()

-    def __eq__(self, that):
+    def compare(self, that, *args, **kwargs):
+        """args/kwargs may include atol for consistency with field::compare"""
         return self.name == that.name and self.dataset == that.dataset
+
+    def __eq__(self, that):
+        return self.compare(that)
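
Note: the resulting equality semantics as a sketch (`a` and `b` stand for two FieldData instances on the same layout; the tolerance values are illustrative):

    # __eq__ now defers to compare() at its default tolerance, and compare()
    # also checks the field name before the floating-point closeness test
    assert (a == b) == a.compare(b)          # default atol=1e-16
    close_enough = a.compare(b, atol=1e-12)  # relaxed tolerance where bit equality is not expected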
10 changes: 9 additions & 1 deletion pyphare/pyphare/pharesee/particles.py
@@ -77,6 +77,11 @@ def size(self):

     def __eq__(self, that):
         if isinstance(that, Particles):
+            if self.size() != that.size():
+                print(
+                    f"particles.py:Particles::eq size diff: {self.size()} != {that.size()}"
+                )
+                return False
             # fails on OSX for some reason
             set_check = set(self.as_tuples()) == set(that.as_tuples())
             if set_check:
@@ -88,9 +93,12 @@ def __eq__(self, that):
                 print(f"particles.py:Particles::eq failed with: {ex}")
                 print_trace()
                 return False
+
+        print(f"particles.py:Particles::eq bad type: {type(that)}")
+        return False

     def __ne__(self, that):
         return not (self == that)

     def select(self, box, box_type="cell"):
         """
         select particles from the given box
9 changes: 6 additions & 3 deletions src/amr/data/initializers/samrai_hdf5_initializer.hpp
@@ -116,15 +116,18 @@ void SamraiH5Interface<GridLayout>::populate_from(std::string const& dir, int co
                                                   int const& mpi_size,
                                                   std::string const& field_name)
 {
+    if (restart_files.size()) // executed per pop, but we only need to run this once
+        return;
+
     for (int rank = 0; rank < mpi_size; ++rank)
     {
         auto const hdf5_filepath = getRestartFileFullPath(dir, idx, mpi_size, rank);
         auto& h5File = *restart_files.emplace_back(std::make_unique<SamraiHDF5File>(hdf5_filepath));
         for (auto const& group : h5File.scan_for_groups({"level_0000", field_name}))
         {
-            auto const em_path = group.substr(0, group.rfind("/"));
-            h5File.patches.emplace_back(h5File.getBoxFromPath(em_path + "/d_box"),
-                                        em_path.substr(0, em_path.rfind("/")));
+            auto const field_path = group.substr(0, group.rfind("/"));
+            auto const& field_box = h5File.getBoxFromPath(field_path + "/d_box");
+            h5File.patches.emplace_back(field_box, field_path.substr(0, field_path.rfind("/")));
         }
     }
 }
@@ -40,12 +40,14 @@ void SamraiHDF5ParticleInitializer<ParticleArray, GridLayout>::loadParticles(
 {
     using Packer = core::ParticlePacker<ParticleArray::dimension>;

-    auto const& overlaps
-        = SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(layout.AMRBox());
+    auto const& dest_box = layout.AMRBox();
+    auto const& overlaps = SamraiH5Interface<GridLayout>::INSTANCE().box_intersections(dest_box);

     for (auto const& [overlap_box, h5FilePtr, pdataptr] : overlaps)
     {
-        auto& h5File = *h5FilePtr;
-        auto& pdata = *pdataptr;
+        auto& h5File = *h5FilePtr;
+        auto& pdata  = *pdataptr;
+
         std::string const poppath = pdata.base_path + "/" + popname + "##default/domainParticles_";
         core::ContiguousParticles<ParticleArray::dimension> soa{0};

@@ -58,14 +60,12 @@
         }

         for (std::size_t i = 0; i < soa.size(); ++i)
-            if (auto const p = soa.copy(i); core::isIn(core::Point{p.iCell}, overlap_box))
+            if (auto const p = soa.copy(i); core::isIn(core::Point{p.iCell}, dest_box))
                 particles.push_back(p);
     }
 }


-
-
 } // namespace PHARE::amr

34 changes: 24 additions & 10 deletions tests/simulator/test_init_from_restart.py
@@ -2,29 +2,35 @@
 import copy
 import unittest
 import subprocess
+import numpy as np
 import pyphare.pharein as ph

+from pyphare.core import phare_utilities as phut
 from pyphare.simulator.simulator import Simulator
+from pyphare.pharesee.hierarchy.patchdata import FieldData, ParticleData
 from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5
+from pyphare.pharesee.hierarchy.hierarchy import format_timestamp
 from pyphare.pharesee.hierarchy.hierarchy_utils import single_patch_for_LO
 from pyphare.pharesee.hierarchy.hierarchy_utils import hierarchy_compare
 from tests.simulator import SimulatorTest, test_restarts
 from tests.diagnostic import dump_all_diags


-timestep = 0.001
-time_step_nbr = 1
+time_step = 0.001
+time_step_nbr = 5
+final_time = time_step_nbr * time_step
 first_mpi_size = 4
 ppc = 100
 cells = 200
 first_out = "phare_outputs/reinit/first"
 secnd_out = "phare_outputs/reinit/secnd"
-timestamps = [0]
-restart_idx = Z = 0
+# timestamps = [0,time_step]
+timestamps = np.arange(0, final_time + time_step, time_step)
+restart_idx = Z = 2
 simInitArgs = dict(
     largest_patch_size=100,
     time_step_nbr=time_step_nbr,
-    time_step=timestep,
+    time_step=time_step,
     cells=cells,
     dl=0.3,
     init_options=dict(dir=f"{first_out}/00000.00{Z}00", mpi_size=first_mpi_size),
@@ -59,13 +65,22 @@ def test_reinit(self):
         sim = ph.Simulation(**copy.deepcopy(simInitArgs))
         setup_model(sim)
         Simulator(sim).run().reset()
-        datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[0])
-        datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[0])
+        fidx, sidx = 2, 0
+        datahier0 = get_all_available_quantities_from_h5(first_out, timestamps[fidx])
+        datahier0.time_hier = {  # swap times
+            format_timestamp(timestamps[sidx]): datahier0.time_hier[
+                format_timestamp(timestamps[fidx])
+            ]
+        }
+        datahier1 = get_all_available_quantities_from_h5(secnd_out, timestamps[sidx])
         qties = ["protons_domain", "alpha_domain", "Bx", "By", "Bz"]
-        ds = [single_patch_for_LO(d, qties) for d in [datahier0, datahier1]]
-        eq = hierarchy_compare(*ds)
+        skip = None  # ["protons_patchGhost", "alpha_patchGhost"]
+        ds = [single_patch_for_LO(d, qties, skip) for d in [datahier0, datahier1]]
+        eq = hierarchy_compare(*ds, atol=1e-14)
         if not eq:
             print(eq)
+            if type(eq.ref) == FieldData:
+                phut.assert_fp_any_all_close(eq.ref[:], eq.cmp[:], atol=1e-16)
         self.assertTrue(eq)


@@ -86,7 +101,6 @@ def launch():
     cmd = f"mpirun -n {first_mpi_size} python3 -O {__file__} lol"
     try:
         p = subprocess.run(cmd.split(" "), check=True, capture_output=True)
-        print(p.stdout, p.stderr)
     except subprocess.CalledProcessError as e:
         print("CalledProcessError", e)

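
Note: the time-swap in test_reinit can be read as the following sketch (names as in the test; `format_timestamp` keys time_hier by formatted time strings). The first run's snapshot at timestamps[fidx] is re-keyed to timestamps[sidx] so it lines up with the restarted run, whose own clock starts at that snapshot:

    # equivalent to the dict rebuild done in the test
    snap = datahier0.time_hier.pop(format_timestamp(timestamps[fidx]))
    datahier0.time_hier = {format_timestamp(timestamps[sidx]): snap}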
