From 04ef8d395f3ba3b012c62546325436477e064a08 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 21 Dec 2023 21:44:03 +0900 Subject: [PATCH] speed up nan replace in sdxl training ref #1009 --- sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index aa2eb5dfd..8983673d2 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -487,7 +487,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index cb97859fa..18c6bd053 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -394,7 +394,7 @@ def remove_model(old_ckpt_name): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 87f303018..6ae5377ba 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -363,7 +363,7 @@ def remove_model(old_ckpt_name): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: