Skip to content

Commit

Permalink
reset memory utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Jan 24, 2025
1 parent ee368a1 commit d63fbcf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
11 changes: 4 additions & 7 deletions finetrainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous, reset_memory_stats
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
Expand Down Expand Up @@ -255,8 +255,7 @@ def collate_fn(batch):

memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)
reset_memory_stats(accelerator.device)

# Precompute latents
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
Expand Down Expand Up @@ -303,8 +302,7 @@ def collate_fn(batch):

memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)
reset_memory_stats(accelerator.device)

# Update dataloader to use precomputed conditions and latents
self.dataloader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -986,8 +984,7 @@ def validate(self, step: int, final_validation: bool = False) -> None:
free_memory()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)
reset_memory_stats(accelerator.device)

if not final_validation:
self.transformer.train()
Expand Down
8 changes: 8 additions & 0 deletions finetrainers/utils/memory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def free_memory() -> None:
# TODO(aryan): handle non-cuda devices


def reset_memory_stats(device: torch.device) -> None:
    """Reset the peak-memory statistics tracker for *device*.

    Only CUDA devices are currently supported; on other backends a warning
    is logged and no state is reset.

    Args:
        device: The device whose peak memory statistics should be reset.
    """
    # TODO: handle for non-cuda devices
    if torch.cuda.is_available():
        # Clears the "peak" counters so subsequent get_memory_statistics()
        # calls report peaks for the next phase only.
        torch.cuda.reset_peak_memory_stats(device)
    else:
        # Fixed typo in the original message ("No CUDA, device found.").
        logger.warning("No CUDA device found. Memory statistics are not available.")


def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if isinstance(x, torch.Tensor):
return x.contiguous()
Expand Down

0 comments on commit d63fbcf

Please sign in to comment.