distunroller set last step periodically #1725

Merged
hnyu merged 4 commits into pytorch from PR_unroller_set_last on Jan 31, 2025
Conversation

hnyu (Collaborator) commented Jan 17, 2025:

This PR lets the DistributedUnroller truncate the experience stream on its own. The stream is truncated if either the predefined max episode length is reached or the env returns a LAST step.

After the stream is truncated, the unroller will switch to sending exps to a different trainer worker (if available).

This PR depends on PR #1723, which fixes a ReplayBuffer sharing issue among processes. I've also added a minimal test to the init of DistributedTrainer to make sure that a ReplayBuffer can be correctly shared with a subprocess.
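Roughly, the behavior described above comes down to two small decisions, sketched below with hypothetical names (max_episode_length, num_trainer_workers); the round-robin worker switch is an assumption for illustration, not necessarily how the actual DistributedUnroller picks the next worker:

def should_truncate(step_is_last, steps_in_episode, max_episode_length):
    # Truncate the experience stream when the env returns a LAST step, or
    # when the predefined max episode length has been reached.
    return step_is_last or steps_in_episode >= max_episode_length

def next_trainer_worker(current_worker, num_trainer_workers):
    # After truncation, switch to another trainer worker so that a whole
    # episode is always sent to the same worker (round-robin is an
    # assumption here).
    return (current_worker + 1) % num_trainer_workers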

@hnyu hnyu requested a review from emailweixu January 17, 2025 22:41
@hnyu hnyu requested a review from Haichao-Zhang January 18, 2025 04:31
# One episode finishes; move to the next worker
# We need to make sure a whole episode is always sent to the same
# worker so that the temporal information is preserved.
exp = alf.nest.set_field(
Contributor:
In the case of a single trainer worker, we don't need to change the step type to LAST.

hnyu (Collaborator, Author) replied:

> In the case of a single trainer worker, we don't need to change the step type to LAST.

If there are multiple unrollers, we still need to set LAST. But it's not straightforward for an unroller to know if there is any other unroller, unless via the trainer. So for simplicity, here we always set LAST.
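For context, the truncated alf.nest.set_field(...) call in the diff excerpt above presumably forces the step type of the outgoing exp to LAST, roughly like the sketch below (the field path 'time_step.step_type' and the exact arguments are my guesses, not the code in the PR):

# Hypothetical: force the step type of the last exp of the truncated stream
# to LAST, regardless of what the env returned.
exp = alf.nest.set_field(
    exp, 'time_step.step_type',
    torch.full_like(exp.time_step.step_type, StepType.LAST))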

if self._exp_socket is None:
    self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*',
                                            self._port, self._id)

try:
Contributor:
should send only for LAST step or episode length reached.

hnyu (Collaborator, Author) replied Jan 24, 2025:

> should send only for LAST step or episode length reached.

Right now, we always send on a per-exp basis instead of waiting for a long trajectory, and the trainer is responsible for maintaining trajectory integrity. The reason is latency: sending a very long trajectory might take a long time (especially with images), blocking the unroller.
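As a sketch of that per-exp sending path (assuming standard pyzmq framing for a ROUTER socket; the frame layout and serialization here are assumptions, not the exact protocol used by the unroller):

import pickle

def send_one_exp(exp_socket, trainer_id, unroller_id, exp_params):
    # Send a single exp immediately instead of buffering a whole trajectory,
    # so a long (e.g. image-heavy) trajectory never blocks the unroller.
    # With a ROUTER socket, the first frame selects the destination peer.
    exp_socket.send_multipart([
        trainer_id,                # identity of the target trainer worker
        unroller_id,               # lets the trainer group exps per unroller
        pickle.dumps(exp_params)   # serialized experience
    ])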

self._num_earliest_frames_ignored = self._core_alg._num_earliest_frames_ignored

# We always test tensor sharing among processes, because
# we rely on undocumented features of PyTorch
Contributor:
Explain what the undocumented feature is.

hnyu (Collaborator, Author) replied:

> Explain what the undocumented feature is.

added explanation
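For reference, the kind of check being added looks roughly like the following self-contained sketch (hypothetical names; the sketch relies on a CPU tensor placed in shared memory staying shared with a subprocess, while a plain numpy array is only copied — the actual test and the exact PyTorch behavior the PR depends on may differ):

import numpy as np
import torch
import torch.multiprocessing as mp

class _Holder(object):
    # Hypothetical container standing in for the ReplayBuffer in the test.
    def __init__(self):
        self.x = torch.zeros([2]).share_memory_()  # tensor in shared memory
        self.y = np.zeros([2])                      # plain numpy array

def _modify(m):
    m.x.add_(1.)  # shared storage: the change is visible to the parent
    m.y += 1.     # only the subprocess's copy is modified

if __name__ == '__main__':
    m = _Holder()
    process = mp.Process(target=_modify, args=(m, ))
    process.start()
    process.join()
    assert torch.equal(m.x, torch.ones([2]))   # tensor was shared
    assert np.array_equal(m.y, np.zeros([2]))  # numpy array was not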

process.join()

# numpy array should not be modified
assert np.allclose(m.y, np.zeros([2]))
Contributor:
should be equal instead of close

hnyu (Collaborator, Author) replied:

> should be equal instead of close

updated
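For reference, since the array should be left completely untouched, a tolerance-based comparison is unnecessary and the assertion could use numpy's exact check, e.g.:

assert np.array_equal(m.y, np.zeros([2]))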

@hnyu hnyu requested a review from emailweixu January 24, 2025 19:56
  # Add the temp exp buffer to the replay buffer
- for exp_params in unroller_exps_buffer[unroller_id]:
+ for i, exp_params in enumerate(
+         unroller_exps_buffer[unroller_id]):
Contributor:
If the batch size of the replay buffer is 1, env_id has to be 0 at the next line.

hnyu (Collaborator, Author) replied:

> If the batch size of the replay buffer is 1, env_id has to be 0 at the next line.

This is true under the current assumption. But since exp_params always contains env_id, we can just use it. Do you mean we should assert that it's equal to 0?

Contributor:
We can just set it to 0 here?

hnyu (Collaborator, Author) replied:

> We can just set it to 0 here?

updated

  # Add the temp exp buffer to the replay buffer
  for exp_params in unroller_exps_buffer[unroller_id]:
-     replay_buffer.add_batch(exp_params, exp_params.env_id)
+     env_id = torch.zeros([1], dtype=torch.int32)
Contributor:
device="cpu"

Contributor:
Should change exp_params.env_id instead, since exp_params.env_id will be stored in the replay buffer and we don't want inconsistency.

exp_params.env_id.zero_()
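Put together, that suggestion amounts to roughly the following in context (a sketch; the merged code may differ in details):

# Add the temp exp buffer to the replay buffer
for exp_params in unroller_exps_buffer[unroller_id]:
    # The trainer-side replay buffer has batch size 1, so force env_id to 0.
    # Zero it in place so the value stored in the replay buffer stays
    # consistent with what add_batch was called with.
    exp_params.env_id.zero_()
    replay_buffer.add_batch(exp_params, exp_params.env_id)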

@hnyu hnyu force-pushed the PR_unroller_set_last branch from 48969cd to 53f6e4d on January 28, 2025 01:03
@hnyu hnyu requested a review from emailweixu January 31, 2025 17:39
@hnyu hnyu merged commit e140cd2 into pytorch Jan 31, 2025
2 checks passed
@hnyu hnyu deleted the PR_unroller_set_last branch January 31, 2025 17:46