Skip to content

Commit

Permalink
speed up nan replace in sdxl training ref #1009
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Dec 21, 2023
1 parent 0676f1a commit 04ef8d3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def remove_model(old_ckpt_name):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def remove_model(old_ckpt_name):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
Expand Down

0 comments on commit 04ef8d3

Please sign in to comment.