diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index b0e2c04a2d0..b043d5d20bc 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -102,7 +102,7 @@ def save_accelerator_state( states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all() # ^^ safe to call this function even if cuda is not available if is_tpu_available(): - states["xm_seed"] = torch.tensor(xm.get_rng_state()) + states["xm_seed"] = xm.get_rng_state() output_states_file = os.path.join(output_dir, states_name) torch.save(states, output_states_file) logger.info(f"Random states saved in {output_states_file}")