From 2c9db5d9f2f6b57f15b9312139d0410ae8ae4f3c Mon Sep 17 00:00:00 2001
From: Maatra
Date: Sat, 20 Apr 2024 14:11:43 +0100
Subject: [PATCH 1/2] passing filtered hyperparameters to accelerate

---
 fine_tune.py                         |  2 +-
 library/train_util.py                | 14 ++++++++++++++
 sdxl_train.py                        |  2 +-
 sdxl_train_control_net_lllite.py     |  2 +-
 sdxl_train_control_net_lllite_old.py |  2 +-
 train_controlnet.py                  |  2 +-
 train_db.py                          |  2 +-
 train_network.py                     |  2 +-
 train_textual_inversion.py           |  2 +-
 train_textual_inversion_XTI.py       |  2 +-
 10 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index c7e6bbd2e..77a1a4f30 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -310,7 +310,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
diff --git a/library/train_util.py b/library/train_util.py
index 15c23f3cc..40be2b05b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
         help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
     )
 
+def filter_sensitive_args(args: argparse.Namespace):
+    sensitive_args = ["wandb_api_key", "huggingface_token"]
+    sensitive_path_args = [
+        "pretrained_model_name_or_path",
+        "vae",
+        "tokenizer_cache_dir",
+        "train_data_dir",
+        "conditioning_data_dir",
+        "reg_data_dir",
+        "output_dir",
+        "logging_dir",
+    ]
+    filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args}
+    return filtered_args
 
 # verify command line args for training
 def verify_command_line_training_args(args: argparse.Namespace):
diff --git a/sdxl_train.py b/sdxl_train.py
index 46d7860be..5a9aa214e 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -487,7 +487,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     sdxl_train_util.sample_images(
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index f89c3628f..770a1f3df 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -353,7 +353,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index e85e978c1..9490cf6f2 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -324,7 +324,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/train_controlnet.py b/train_controlnet.py
index f4c94e8d9..793f79c7d 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -344,7 +344,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/train_db.py b/train_db.py
index 1de504ed8..4f9018293 100644
--- a/train_db.py
+++ b/train_db.py
@@ -290,7 +290,7 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
diff --git a/train_network.py b/train_network.py
index c99d37247..1dca437cf 100644
--- a/train_network.py
+++ b/train_network.py
@@ -753,7 +753,7 @@ def load_model_hook(models, input_dir):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 10fce2677..56a387391 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -510,7 +510,7 @@ def train(self, args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index ddd03d532..691785239 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -407,7 +407,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing

From b886d0a359526f5715f3ced05697d406a169055b Mon Sep 17 00:00:00 2001
From: Maatra
Date: Sat, 20 Apr 2024 14:36:47 +0100
Subject: [PATCH 2/2] Cleaned typing to be in line with accelerate hyperparameter type restrictions

---
 library/train_util.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 40be2b05b..75b3420d9 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3390,7 +3390,20 @@ def filter_sensitive_args(args: argparse.Namespace):
         "output_dir",
         "logging_dir",
     ]
-    filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args}
+    filtered_args = {}
+    for k, v in vars(args).items():
+        # filter out sensitive values
+        if k not in sensitive_args + sensitive_path_args:
+            # Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
+            if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
+                filtered_args[k] = v
+            # accelerate does not support lists
+            elif isinstance(v, list):
+                filtered_args[k] = f"{v}"
+            # accelerate does not support objects
+            elif isinstance(v, object):
+                filtered_args[k] = f"{v}"
+
     return filtered_args
 
 # verify command line args for training
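Note (not part of the patch series): a minimal sketch of how the patched helper behaves end to end, assuming the patched library/train_util.py above is importable from the repo root; the Namespace fields and their values below are hypothetical, chosen only to exercise each branch.

    import argparse

    import library.train_util as train_util  # train_util.py as patched above

    # Toy stand-in for parsed training args; values are made up for illustration.
    args = argparse.Namespace(
        wandb_api_key="secret",                # dropped: listed in sensitive_args
        output_dir="/path/to/output",          # dropped: listed in sensitive_path_args
        learning_rate=1e-4,                    # kept: float is an accepted tracker type
        max_train_steps=1600,                  # kept: int is accepted
        optimizer_args=["weight_decay=0.01"],  # stringified: lists are not accepted
    )

    config = train_util.filter_sensitive_args(args)
    # -> {'learning_rate': 0.0001, 'max_train_steps': 1600,
    #     'optimizer_args': "['weight_decay=0.01']"}

Each trainer then passes this dict as config= to accelerator.init_trackers(), so the run's hyperparameters appear in the tracker (e.g. wandb) without logging credentials or local paths; stringifying lists and other objects keeps them visible while satisfying accelerate's type restriction.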