diff --git a/diffengine/engine/hooks/compile_hook.py b/diffengine/engine/hooks/compile_hook.py
index 60c5ff5..43d5c29 100644
--- a/diffengine/engine/hooks/compile_hook.py
+++ b/diffengine/engine/hooks/compile_hook.py
@@ -12,15 +12,17 @@ class CompileHook(Hook):
     ----
         backend (str): The backend to use for compilation. Defaults to "inductor".
+        mode (str): The mode to use for compilation. Defaults to None.
         compile_unet (bool): Whether to compile the unet. Defaults to False.
     """
 
     priority = "VERY_LOW"
 
-    def __init__(self, backend: str = "inductor", *,
+    def __init__(self, backend: str = "inductor", mode: str | None = None, *,
                  compile_unet: bool = False) -> None:
         super().__init__()
         self.backend = backend
+        self.mode = mode
         self.compile_unet = compile_unet
 
     def before_train(self, runner) -> None:
@@ -34,17 +36,18 @@ def before_train(self, runner) -> None:
         if is_model_wrapper(model):
             model = model.module
         if self.compile_unet:
-            model.unet = torch.compile(model.unet, backend=self.backend)
+            model.unet = torch.compile(model.unet, backend=self.backend,
+                                       mode=self.mode)
         if hasattr(model, "text_encoder"):
             model.text_encoder = torch.compile(
-                model.text_encoder, backend=self.backend)
+                model.text_encoder, backend=self.backend, mode=self.mode)
         if hasattr(model, "text_encoder_one"):
             model.text_encoder_one = torch.compile(
-                model.text_encoder_one, backend=self.backend)
+                model.text_encoder_one, backend=self.backend, mode=self.mode)
         if hasattr(model, "text_encoder_two"):
             model.text_encoder_two = torch.compile(
-                model.text_encoder_two, backend=self.backend)
+                model.text_encoder_two, backend=self.backend, mode=self.mode)
         if hasattr(model, "vae"):
             model.vae = torch.compile(
-                model.vae, backend=self.backend)
+                model.vae, backend=self.backend, mode=self.mode)
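
Usage note (not part of the diff): the new mode argument is forwarded as-is to
torch.compile(..., mode=...), which accepts values such as "default",
"reduce-overhead", and "max-autotune". Below is a minimal sketch of enabling it,
assuming the hook is registered through an MMEngine-style custom_hooks config;
the exact config layout and values are illustrative assumptions, not taken from
this change.

    # Hypothetical config snippet: register CompileHook with the new mode option.
    custom_hooks = [
        dict(
            type="CompileHook",
            backend="inductor",   # forwarded to torch.compile(backend=...)
            mode="max-autotune",  # new in this change; defaults to None
            compile_unet=True,    # also compile the unet, not only text encoders / VAE
        ),
    ]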