diff --git a/training/mochi-1/text_to_video_lora.py b/training/mochi-1/text_to_video_lora.py
index 89caf22f..af1ce626 100644
--- a/training/mochi-1/text_to_video_lora.py
+++ b/training/mochi-1/text_to_video_lora.py
@@ -40,10 +40,6 @@ from dataset_simple import LatentEmbedDataset
 
 
 import sys
-
-
-sys.path.append("..")
-
 from utils import print_memory, reset_memory  # isort:skip
 
 
diff --git a/training/mochi-1/utils.py b/training/mochi-1/utils.py
new file mode 100644
index 00000000..76fe35c2
--- /dev/null
+++ b/training/mochi-1/utils.py
@@ -0,0 +1,25 @@
+import gc
+import logging
+from typing import Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def reset_memory(device: Union[str, torch.device]) -> None:
+    """Run the GC, drop PyTorch's cached CUDA blocks, and reset peak/accumulated memory stats for `device`."""
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats(device)
+    torch.cuda.reset_accumulated_memory_stats(device)
+
+
+def print_memory(device: Union[str, torch.device]) -> None:
+    """Print currently-allocated, peak-allocated, and peak-reserved CUDA memory on `device`, in GiB."""
+    memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
+    max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
+    max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+    print(f"{memory_allocated=:.3f} GB")
+    print(f"{max_memory_allocated=:.3f} GB")
+    print(f"{max_memory_reserved=:.3f} GB")