Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Fix GPU memory leak that came up in RCN example #97

Merged
merged 4 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ repos:
hooks:
- id: mypy
additional_dependencies: [tokenize-rt==3.2.0]
ci:
autoupdate_schedule: 'quarterly'
73 changes: 42 additions & 31 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Sequence,
Tuple,
Union,
cast,
)

import jax
Expand All @@ -45,7 +46,7 @@ class FactorGraph:
"""

variables: Union[
Mapping[Hashable, groups.VariableGroup],
Mapping[Any, groups.VariableGroup],
Sequence[groups.VariableGroup],
groups.VariableGroup,
]
Expand Down Expand Up @@ -399,19 +400,17 @@ class LogPotentials:

def __post_init__(self):
if self.value is None:
object.__setattr__(
self, "value", jax.device_put(self.fg_state.log_potentials)
)
object.__setattr__(self, "value", self.fg_state.log_potentials)
else:
if not self.value.shape == self.fg_state.log_potentials.shape:
raise ValueError(
f"Expected log potentials shape {self.fg_state.log_potentials.shape}. "
f"Got {self.value.shape}."
)

object.__setattr__(self, "value", jax.device_put(self.value))
object.__setattr__(self, "value", self.value)

def __getitem__(self, name: Any) -> jnp.ndarray:
def __getitem__(self, name: Any) -> np.ndarray:
"""Function to query log potentials for a named factor group or a factor.

Args:
Expand All @@ -421,21 +420,20 @@ def __getitem__(self, name: Any) -> jnp.ndarray:
Returns:
The queried log potentials.
"""
value = cast(np.ndarray, self.value)
if not isinstance(name, Hashable):
name = frozenset(name)

if name in self.fg_state.named_factor_groups:
factor_group = self.fg_state.named_factor_groups[name]
start = self.fg_state.factor_group_to_potentials_starts[factor_group]
log_potentials = jax.device_put(self.value)[
log_potentials = value[
start : start + factor_group.factor_group_log_potentials.shape[0]
]
elif frozenset(name) in self.fg_state.variables_to_factors:
factor = self.fg_state.variables_to_factors[frozenset(name)]
start = self.fg_state.factor_to_potentials_starts[factor]
log_potentials = jax.device_put(self.value)[
start : start + factor.log_potentials.shape[0]
]
log_potentials = value[start : start + factor.log_potentials.shape[0]]
else:
raise ValueError(f"Invalid name {name} for log potentials updates.")

Expand All @@ -460,8 +458,12 @@ def __setitem__(
object.__setattr__(
self,
"value",
update_log_potentials(
jax.device_put(self.value), {name: jax.device_put(data)}, self.fg_state
np.asarray(
update_log_potentials(
jax.device_put(self.value),
{name: jax.device_put(data)},
self.fg_state,
)
),
)

Expand Down Expand Up @@ -545,12 +547,12 @@ class FToVMessages:
"""

fg_state: FactorGraphState
value: Optional[Union[np.ndarray, jnp.ndarray]] = None
value: Optional[np.ndarray] = None

def __post_init__(self):
if self.value is None:
object.__setattr__(
self, "value", jnp.zeros(self.fg_state.total_factor_num_states)
self, "value", np.zeros(self.fg_state.total_factor_num_states)
)
else:
if not self.value.shape == (self.fg_state.total_factor_num_states,):
Expand All @@ -559,9 +561,9 @@ def __post_init__(self):
f"Got {self.value.shape}."
)

object.__setattr__(self, "value", jax.device_put(self.value))
object.__setattr__(self, "value", self.value)

def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray:
def __getitem__(self, names: Tuple[Any, Any]) -> np.ndarray:
"""Function to query messages from a factor to a variable

Args:
Expand All @@ -574,6 +576,7 @@ def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray:

Raises: ValueError if provided names are not valid for querying ftov messages.
"""
value = cast(np.ndarray, self.value)
if not (
isinstance(names, tuple)
and len(names) == 2
Expand All @@ -589,7 +592,7 @@ def __getitem__(self, names: Tuple[Any, Any]) -> jnp.ndarray:
start = self.fg_state.factor_to_msgs_starts[factor] + np.sum(
factor.edges_num_states[: factor.variables.index(variable)]
)
msgs = jax.device_put(self.value)[start : start + variable.num_states]
msgs = value[start : start + variable.num_states]
return msgs

@typing.overload
Expand Down Expand Up @@ -634,8 +637,12 @@ def __setitem__(self, names, data) -> None:
object.__setattr__(
self,
"value",
update_ftov_msgs(
jax.device_put(self.value), {names: jax.device_put(data)}, self.fg_state
np.asarray(
update_ftov_msgs(
jax.device_put(self.value),
{names: jax.device_put(data)},
self.fg_state,
)
),
)

Expand Down Expand Up @@ -690,21 +697,21 @@ class Evidence:
"""

fg_state: FactorGraphState
value: Optional[Union[np.ndarray, jnp.ndarray]] = None
value: Optional[np.ndarray] = None

def __post_init__(self):
if self.value is None:
object.__setattr__(self, "value", jnp.zeros(self.fg_state.num_var_states))
object.__setattr__(self, "value", np.zeros(self.fg_state.num_var_states))
else:
if self.value.shape != (self.fg_state.num_var_states,):
raise ValueError(
f"Expected evidence shape {(self.fg_state.num_var_states,)}. "
f"Got {self.value.shape}."
)

object.__setattr__(self, "value", jax.device_put(self.value))
object.__setattr__(self, "value", self.value)

def __getitem__(self, name: Any) -> jnp.ndarray:
def __getitem__(self, name: Any) -> np.ndarray:
"""Function to query evidence for a variable

Args:
Expand All @@ -713,9 +720,10 @@ def __getitem__(self, name: Any) -> jnp.ndarray:
Returns:
evidence for the queried variable
"""
value = cast(np.ndarray, self.value)
variable = self.fg_state.variable_group[name]
start = self.fg_state.vars_to_starts[variable]
evidence = jax.device_put(self.value)[start : start + variable.num_states]
evidence = value[start : start + variable.num_states]
return evidence

def __setitem__(
Expand All @@ -736,8 +744,12 @@ def __setitem__(
object.__setattr__(
self,
"value",
update_evidence(
jax.device_put(self.value), {name: jax.device_put(data)}, self.fg_state
np.asarray(
update_evidence(
jax.device_put(self.value),
{name: jax.device_put(data)},
self.fg_state,
),
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not expert on jax but for my own clarification.
Do we not need to clear the variables self.value and data from jax memory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.value and data would stay on CPU. Intermediate objects on GPU will be cleared automatically if you use the platform option.

The leak came from a failure to clear large jit compilation cache (the cache is large because it involves a large constant array, i.e. the wiring). A completely leak-free way is to include wiring as part of the arguments, but in my experiments that leads to slow compiling and hurts performance. Current solution is a bit of a compromise but should suffice for most cases.

)

Expand Down Expand Up @@ -789,7 +801,6 @@ def BP(bp_state: BPState, num_iters: int) -> Tuple[Callable, Callable, Callable]
int(bp_state.fg_state.wiring.factor_configs_edge_states[-1, 0]) + 1
)

@jax.jit
def run_bp(
log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None,
ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None,
Expand All @@ -810,20 +821,20 @@ def run_bp(
Returns:
A BPArrays containing the updated log_potentials, ftov_msgs and evidence.
"""
wiring = jax.device_put(bp_state.fg_state.wiring)
log_potentials = jax.device_put(bp_state.log_potentials.value)
wiring = bp_state.fg_state.wiring
log_potentials = bp_state.log_potentials.value
if log_potentials_updates is not None:
log_potentials = update_log_potentials(
log_potentials, log_potentials_updates, bp_state.fg_state
)

ftov_msgs = jax.device_put(bp_state.ftov_msgs.value)
ftov_msgs = bp_state.ftov_msgs.value
if ftov_msgs_updates is not None:
ftov_msgs = update_ftov_msgs(
ftov_msgs, ftov_msgs_updates, bp_state.fg_state
)

evidence = jax.device_put(bp_state.evidence.value)
evidence = bp_state.evidence.value
if evidence_updates is not None:
evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state)

Expand Down
5 changes: 3 additions & 2 deletions tests/test_pgmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.random import default_rng
from scipy.ndimage import gaussian_filter

from pgmax.fg import graph, groups
from pgmax.fg import graph, groups, nodes

# Set random seed for rng
rng = default_rng(23)
Expand Down Expand Up @@ -365,6 +365,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
# Run BP
# Set the evidence
bp_state = fg.bp_state
assert isinstance(jax.device_put(fg.fg_state.wiring), nodes.EnumerationWiring)
bp_state.evidence["grid_vars"] = grid_evidence_arr
bp_state.evidence["additional_vars"] = additional_vars_evidence_dict
run_bp, _, get_beliefs = graph.BP(bp_state, 100)
Expand Down Expand Up @@ -422,5 +423,5 @@ def binary_connected_variables(
bp_state.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0])
bp_state.evidence[0, 0, 0]
bp_state.evidence[1, 0, 0]
assert isinstance(bp_state.evidence.value, jnp.ndarray)
assert isinstance(bp_state.evidence.value, np.ndarray)
assert len(fg.factors) == 7056