Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resume correct step for "resume from state" feature. #1353

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,15 @@ def set_caching_mode(self, mode):

def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
if epoch > self.current_epoch:
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
num_epochs = epoch - self.current_epoch
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch

def set_current_step(self, step):
self.current_step = step
Expand Down
108 changes: 101 additions & 7 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,24 @@ def train(self, args):
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
if accelerator.is_main_process:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
if len(weights) > i:
weights.pop(i)
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")

# save current ecpoch and step
train_state_file = os.path.join(output_dir, "train_state.json")
# +1 is needed because the state is saved before current_step is set from global_step
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
with open(train_state_file, "w", encoding="utf-8") as f:
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)

steps_from_state = None

def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
Expand All @@ -514,6 +521,15 @@ def load_model_hook(models, input_dir):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")

# load current epoch and step to
nonlocal steps_from_state
train_state_file = os.path.join(input_dir, "train_state.json")
if os.path.exists(train_state_file):
with open(train_state_file, "r", encoding="utf-8") as f:
data = json.load(f)
steps_from_state = data["current_step"]
logger.info(f"load train state from {train_state_file}: {data}")

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

Expand Down Expand Up @@ -757,7 +773,53 @@ def load_model_hook(models, input_dir):
if key in metadata:
minimum_metadata[key] = metadata[key]

progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
# calculate steps to skip when resuming or starting from a specific step
initial_step = 0
if args.initial_epoch is not None or args.initial_step is not None:
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
if steps_from_state is not None:
logger.warning(
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
)
if args.initial_step is not None:
initial_step = args.initial_step
else:
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
initial_step = (args.initial_epoch - 1) * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
else:
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
if steps_from_state is not None:
initial_step = steps_from_state
steps_from_state = None

if initial_step > 0:
assert (
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)

epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
if not args.resume:
logger.info(
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps
else:
# if not, only epoch no is skipped for informative purpose
epoch_to_start = initial_step // math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
initial_step = 0 # do not skip

global_step = 0

noise_scheduler = DDPMScheduler(
Expand Down Expand Up @@ -816,16 +878,29 @@ def remove_model(old_ckpt_name):
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# training loop
for epoch in range(num_train_epochs):
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)

for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)

accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1)
initial_step = 1

for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue

with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)

Expand Down Expand Up @@ -1126,6 +1201,25 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--skip_until_initial_step",
action="store_true",
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
)
parser.add_argument(
"--initial_epoch",
type=int,
default=None,
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
)
parser.add_argument(
"--initial_step",
type=int,
default=None,
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
Expand Down