Skip to content

Commit

Permalink
Merge pull request #86 from okotaku/feat/compile_mode
Browse files Browse the repository at this point in the history
[Enhance] Support compile mode
  • Loading branch information
okotaku authored Oct 30, 2023
2 parents 5786285 + 1af92c4 commit e93f733
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions diffengine/engine/hooks/compile_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ class CompileHook(Hook):
----
backend (str): The backend to use for compilation.
Defaults to "inductor".
mode (str): The mode to use for compilation. Defaults to None.
compile_unet (bool): Whether to compile the unet. Defaults to False.
"""

priority = "VERY_LOW"

def __init__(self, backend: str = "inductor", *,
def __init__(self, backend: str = "inductor", mode: str | None = None, *,
compile_unet: bool = False) -> None:
super().__init__()
self.backend = backend
self.mode = mode
self.compile_unet = compile_unet

def before_train(self, runner) -> None:
Expand All @@ -34,17 +36,18 @@ def before_train(self, runner) -> None:
if is_model_wrapper(model):
model = model.module
if self.compile_unet:
model.unet = torch.compile(model.unet, backend=self.backend)
model.unet = torch.compile(model.unet, backend=self.backend,
mode=self.mode)

if hasattr(model, "text_encoder"):
model.text_encoder = torch.compile(
model.text_encoder, backend=self.backend)
model.text_encoder, backend=self.backend, mode=self.mode)
if hasattr(model, "text_encoder_one"):
model.text_encoder_one = torch.compile(
model.text_encoder_one, backend=self.backend)
model.text_encoder_one, backend=self.backend, mode=self.mode)
if hasattr(model, "text_encoder_two"):
model.text_encoder_two = torch.compile(
model.text_encoder_two, backend=self.backend)
model.text_encoder_two, backend=self.backend, mode=self.mode)
if hasattr(model, "vae"):
model.vae = torch.compile(
model.vae, backend=self.backend)
model.vae, backend=self.backend, mode=self.mode)

0 comments on commit e93f733

Please sign in to comment.