Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline some helper functions in driver #372

Merged
merged 1 commit into from
Nov 3, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 28 additions & 51 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from pace.dsl.dace.dace_config import DaceConfig
from pace.dsl.dace.orchestration import dace_inhibitor, orchestrate
from pace.dsl.stencil_config import CompilationConfig, RunMode
from pace.fv3core.initialization.dycore_state import DycoreState

# TODO: move update_atmos_state into pace.driver
from pace.stencils import update_atmos_state
Expand Down Expand Up @@ -397,7 +396,7 @@ def __init__(
communicator = CubedSphereCommunicator.from_layout(
comm=self.comm, layout=self.config.layout
)
self.update_driver_config_with_communicator(communicator)
self._update_driver_config_with_communicator(communicator)

if self.config.stencil_config.compilation_config.run_mode == RunMode.Build:

Expand Down Expand Up @@ -436,17 +435,6 @@ def exit_instead_of_build(self):
method_to_orchestrate="_critical_path_step_all",
dace_compiletime_args=["timer"],
)
orchestrate(
obj=self,
config=self.config.stencil_config.dace_config,
method_to_orchestrate="_step_dynamics",
dace_compiletime_args=["state", "timer"],
)
orchestrate(
obj=self,
config=self.config.stencil_config.dace_config,
method_to_orchestrate="_step_physics",
)

self.quantity_factory, self.stencil_factory = _setup_factories(
config=config,
Expand Down Expand Up @@ -525,7 +513,7 @@ def exit_instead_of_build(self):

self._time_run = self.config.start_time

def update_driver_config_with_communicator(
def _update_driver_config_with_communicator(
self, communicator: CubedSphereCommunicator
) -> None:
dace_config = DaceConfig(
Expand Down Expand Up @@ -554,9 +542,10 @@ def _callback_diagnostics(self):
self.diagnostics.store(time=self._time_run, state=self.state)

@dace_inhibitor
def end_of_step_actions(self, step: int):
"""Gather operations unrelated to computation.
Using a function allows those actions to be removed from the orchestration path.
def _end_of_step_actions(self, step: int):
"""
Gather operations unrelated to computation.
Using a method allows those actions to be removed from the orchestration path.
"""
if __debug__:
logger.info(f"Finished stepping {step}")
Expand All @@ -583,16 +572,32 @@ def _critical_path_step_all(

This function must remain orchestrateable by DaCe (e.g.
all code not parsable due to python complexity needs to be moved
to a callback, like end_of_step_actions)."""
to a callback, like end_of_step_actions).
"""
for step in dace.nounroll(range(steps_count)):
with timer.clock("mainloop"):
self._step_dynamics(
self.state.dycore_state,
self.performance_collector.timestep_timer,
self.dycore.step_dynamics(
state=self.state.dycore_state,
timer=timer,
)
if not self.config.disable_step_physics:
self._step_physics(timestep=dt)
self.end_of_step_actions(step)
self.dycore_to_physics(
dycore_state=self.state.dycore_state,
physics_state=self.state.physics_state,
tendency_state=self.state.tendency_state,
timestep=float(dt),
)
if not self.config.dycore_only:
self.physics(self.state.physics_state, timestep=float(dt))
self.end_of_step_update(
dycore_state=self.state.dycore_state,
phy_state=self.state.physics_state,
u_dt=self.state.tendency_state.u_dt.storage,
v_dt=self.state.tendency_state.v_dt.storage,
pt_dt=self.state.tendency_state.pt_dt.storage,
dt=float(dt),
)
self._end_of_step_actions(step)

def step_all(self):
logger.info("integrating driver forward in time")
Expand All @@ -608,34 +613,6 @@ def step_all(self):
{self.comm.Get_rank()}.prof"
)

def _step_dynamics(
self,
state: DycoreState,
timer: pace.util.Timer,
):
self.dycore.step_dynamics(
state=state,
timer=timer,
)

def _step_physics(self, timestep: float):
self.dycore_to_physics(
dycore_state=self.state.dycore_state,
physics_state=self.state.physics_state,
tendency_state=self.state.tendency_state,
timestep=float(timestep),
)
if not self.config.dycore_only:
self.physics(self.state.physics_state, timestep=float(timestep))
self.end_of_step_update(
dycore_state=self.state.dycore_state,
phy_state=self.state.physics_state,
u_dt=self.state.tendency_state.u_dt.storage,
v_dt=self.state.tendency_state.v_dt.storage,
pt_dt=self.state.tendency_state.pt_dt.storage,
dt=float(timestep),
)

def _write_performance_json_output(self):
self.performance_collector.write_out_performance(
self.config.stencil_config.compilation_config.backend,
Expand Down