From f5f9cc02e98ea3c5469559b18404c26830d0dca1 Mon Sep 17 00:00:00 2001
From: Aryan Gupta <97878444+guptaaryan16@users.noreply.github.com>
Date: Tue, 14 Jan 2025 23:01:28 +0530
Subject: [PATCH] Fix: utils error in mochi finetuning script (#218)

---
 training/mochi-1/text_to_video_lora.py |  4 ----
 training/mochi-1/utils.py              | 24 ++++++++++++++++++++++++
 2 files changed, 24 insertions(+), 4 deletions(-)
 create mode 100644 training/mochi-1/utils.py

diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py
index 89caf22f..af1ce626 100644
--- a/training/mochi-1/text_to_video_lora.py
+++ b/training/mochi-1/text_to_video_lora.py
@@ -40,10 +40,6 @@
 from dataset_simple import LatentEmbedDataset
 
 import sys
-
-
-sys.path.append("..")
-
 from utils import print_memory, reset_memory  # isort:skip
 
 
diff --git a/training/mochi-1/utils.py b/training/mochi-1/utils.py
new file mode 100644
index 00000000..76fe35c2
--- /dev/null
+++ b/training/mochi-1/utils.py
@@ -0,0 +1,24 @@
+import gc
+from typing import Union
+
+import torch
+from accelerate.logging import get_logger
+
+logger = get_logger(__name__)
+
+def reset_memory(device: Union[str, torch.device]) -> None:
+    """Release cached CUDA memory and reset the peak/accumulated memory stats."""
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats(device)
+    torch.cuda.reset_accumulated_memory_stats(device)
+
+
+def print_memory(device: Union[str, torch.device]) -> None:
+    """Print the current and peak CUDA memory usage on `device`, in GB."""
+    memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
+    max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
+    max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+    print(f"{memory_allocated=:.3f} GB")
+    print(f"{max_memory_allocated=:.3f} GB")
+    print(f"{max_memory_reserved=:.3f} GB")
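
Note: a minimal sketch of how these helpers are meant to be called from the
training script once utils.py lives alongside it (the choice of device and
the placement of the calls around a training step are illustrative
assumptions, not part of the patch):

    import torch

    from utils import print_memory, reset_memory

    device = torch.device("cuda")
    reset_memory(device)   # release cached allocations, zero the peak counters
    ...                    # run a training step or validation pass here
    print_memory(device)   # report current and peak CUDA memory in GB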