Skip to content

Commit

Permalink
Update aoti calls to utilize new export and packaging APIs (#1455)
Browse files Browse the repository at this point in the history
Co-authored-by: Jack-Khuu <[email protected]>
  • Loading branch information
2 people authored and vmpuri committed Feb 4, 2025
1 parent a64b9e3 commit 84d2232
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
3 changes: 1 addition & 2 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,8 @@ def do_nothing(max_batch_size, max_seq_length):
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
from torch._inductor.package import load_package

aoti_compiled_model = load_package(
aoti_compiled_model = torch._inductor.aoti_load_package(
str(builder_args.aoti_package_path.absolute())
)

Expand Down
13 changes: 8 additions & 5 deletions torchchat/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,20 @@ def export_for_server(
if not package:
options = {"aot_inductor.output_path": output_path}

path = torch._export.aot_compile(
ep = torch.export.export(
model,
example_inputs,
dynamic_shapes=dynamic_shapes,
options=options,
)

if package:
from torch._inductor.package import package_aoti

path = package_aoti(output_path, path)
path = torch._inductor.aoti_compile_and_package(
ep, package_path=output_path, inductor_configs=options
)
else:
path = torch._inductor.aot_compile(
ep.module(), example_inputs, options=options
)

print(f"The generated packaged model can be found at: {path}")
return path
Expand Down

0 comments on commit 84d2232

Please sign in to comment.