Fix memory_planning API to use run() (#8622)
Summary:

Update the memory_planning pass to use run() instead of the deprecated __call__(); a sketch of the pattern appears below.

Reviewed By: zonglinpeng

Differential Revision: D68939461
Eashan Garg authored and facebook-github-bot committed Feb 21, 2025
1 parent ad4675a commit bdbe122
Showing 2 changed files with 33 additions and 18 deletions.
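
The core change is a back-compat shim: the deprecated __call__ now delegates to a new run() entry point that additionally accepts an optional graph signature. A minimal, self-contained sketch of that pattern (the class name and the Any/Optional types are illustrative stand-ins, not the actual ExecuTorch classes):

from typing import Any, Optional


class MemoryPlanningSketch:
    """Illustrative stand-in for the memory-planning pass in this diff."""

    def __call__(self, graph_module: Any) -> Any:
        # Deprecated entry point: kept as a thin shim so existing call
        # sites that invoke the pass directly keep working.
        return self.run(graph_module)

    def run(self, graph_module: Any, graph_signature: Optional[Any] = None) -> Any:
        # New entry point: the optional graph signature lets spec
        # collection tell graph inputs/outputs apart from intermediates.
        return graph_module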
28 changes: 22 additions & 6 deletions backends/cadence/aot/memory_planning.py
@@ -46,6 +46,7 @@ def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
 
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> Iterable[TensorSpec]:
@@ -56,6 +57,7 @@ def collect_specs_from_graph_module(
     # Collect the specs from all the nodes in the graph module, and return it
     return collect_specs_from_nodes(
         graph_module.graph.nodes,
+        graph_signature,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     )
@@ -107,7 +109,7 @@ def memory_available(spec: TensorSpec) -> bool:
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -182,7 +184,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -250,6 +252,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
 def find_peak_memory_usages_per_memory(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -265,7 +268,7 @@ def find_peak_memory_usages_per_memory(
 
     # go through all nodes in the graph, collect memory usage per spec.mem_id
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if mem_constraints is not None and mem_constraints.skipped_spec(spec):
             continue
@@ -288,6 +291,7 @@
 
 def find_peak_memory_usage(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -303,7 +307,7 @@ def find_peak_memory_usage(
 
     # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if spec.lifetime[0] is None or (
             mem_constraints is not None and mem_constraints.skipped_spec(spec)
@@ -358,6 +362,7 @@ def print_memory_planning_info(
     # Get the peak memory usages per memory space
     peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -393,6 +398,7 @@ def print_memory_planning_info(
     # Get the total peak memory usage across all memory spaces
     total_peak_memory_usage = find_peak_memory_usage(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -453,7 +459,17 @@ def _init_mem_algos(self) -> None:
             greedy_by_size_for_offset_calculation_with_hierarchy,
         ]
 
-    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+    def __call__(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> PassResult:
+        return self.run(graph_module)
+
+    def run(
+        self,
+        graph_module: torch.fx.GraphModule,
+        graph_signature: Optional[ExportGraphSignature] = None,
+    ) -> PassResult:
         mem_constraints = MemConstraints(
             opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
@@ -475,6 +491,6 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             alloc_graph_output=self.alloc_graph_output,
             alignment=self.mem_alignment,
         )
-        mem_planning(graph_module)
+        mem_planning.run(graph_module, graph_signature)
 
         return PassResult(graph_module, True)
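
Caller-side, the migration mirrors the print_memory_planning_info hunks above: fetch both the graph module and the graph signature from the same exported program and pass them to run(). A hedged usage sketch (memory_planning_pass stands in for the pass instance constructed in this file; executorch_prog is assumed to be an already-built program, as in the hunks above):

# Both attributes come from the same exported program, as in the diff above.
ep = executorch_prog.exported_program()
result = memory_planning_pass.run(ep.graph_module, ep.graph_signature)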
23 changes: 11 additions & 12 deletions backends/cadence/aot/tests/test_memory_passes.py
@@ -46,14 +46,13 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
         inputs = (torch.ones(batch_size, input_dim),)
         model = PeakMemoryTestModel(input_dim, hidden_dim, output_dim)
 
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(model, inputs)
-            .exported_program()
-            .graph_module
-        )
+        exported_program = compiler.export_to_executorch_gen_etrecord(
+            model, inputs
+        ).exported_program()
 
         peak_usage, _ = find_peak_memory_usage(
-            graph_module,
+            exported_program.graph_module,
+            exported_program.graph_signature,
             mem_constraints=None,
             alloc_graph_input=True,
             alloc_graph_output=True,
@@ -73,14 +72,13 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
             input_dim, hidden_dim, hidden_dim, hidden_dim, output_dim
         )
 
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(model, inputs)
-            .exported_program()
-            .graph_module
-        )
+        exported_program = compiler.export_to_executorch_gen_etrecord(
+            model, inputs
+        ).exported_program()
 
         peak_usage, _ = find_peak_memory_usage(
-            graph_module,
+            exported_program.graph_module,
+            exported_program.graph_signature,
             mem_constraints=None,
             alloc_graph_input=True,
             alloc_graph_output=True,
@@ -111,6 +109,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         peak_usage, _ = find_peak_memory_usage(
             graph_module,
+            executorch_prog.exported_program().graph_signature,
             alloc_graph_input=False,
             alloc_graph_output=False,
             mem_constraints=None,
