distunroller set last step periodically #1725
Changes from 3 commits
@@ -16,6 +16,7 @@ | |
from typing import Callable | ||
import time | ||
import io | ||
import queue | ||
import random | ||
import threading | ||
import subprocess | ||
|
@@ -24,21 +25,21 @@ | |
import torch | ||
import torch.multiprocessing as mp | ||
from multiprocessing.shared_memory import SharedMemory | ||
from multiprocessing import Manager | ||
|
||
import alf | ||
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm | ||
from alf.algorithms.config import TrainerConfig | ||
from alf.environments.alf_environment import AlfEnvironment | ||
from alf.experience_replayers.replay_buffer import ReplayBuffer | ||
from alf.data_structures import Experience, make_experience | ||
from alf.data_structures import Experience, make_experience, StepType | ||
from alf.utils.per_process_context import PerProcessContext | ||
from alf.utils import dist_utils | ||
from alf.utils.summary_utils import record_time | ||
from alf.utils.common_test import _test_tensor_sharing | ||
|
||
|
||
class UnrollerMessage(object): | ||
# unroller indicates end of experience for the current segment | ||
EXP_SEG_END = 'unroller: last_seg_exp' | ||
# confirmation | ||
OK = 'unroller: ok' | ||
|
||
|
@@ -104,7 +105,6 @@ def __init__(self, | |
env: AlfEnvironment = None, | ||
config: TrainerConfig = None, | ||
optimizer: alf.optimizers.Optimizer = None, | ||
checkpoint: str = None, | ||
debug_summaries: bool = False, | ||
name: str = "DistributedOffPolicyAlgorithm", | ||
**kwargs): | ||
|
@@ -116,10 +116,6 @@ def __init__(self, | |
port: port number for communication on the *current* machine. | ||
env: The environment to interact with. Its batch size must be 1. | ||
optimizer: optimizer for the training the core algorithm. | ||
checkpoint: a string in the format of "prefix@path", | ||
where the "prefix" is the multi-step path to the contents in the | ||
checkpoint to be loaded. "path" is the full path to the checkpoint | ||
file saved by ALF. Refer to ``Algorithm`` for more details. | ||
debug_summaries: True if debug summaries should be created. | ||
name: the name of this algorithm. | ||
*args: args to pass to ``core_alg_ctor``. | ||
|
@@ -147,7 +143,8 @@ def __init__(self, | |
env=env, | ||
config=config, | ||
optimizer=optimizer, | ||
checkpoint=checkpoint, | ||
# Prevent in-alg ckpt since there is no such use case.
checkpoint=None, | ||
debug_summaries=debug_summaries, | ||
name=name) | ||
|
||
|
@@ -202,7 +199,7 @@ def after_train_iter(self, root_inputs, rollout_info): | |
|
||
|
||
def receive_experience_data(replay_buffer: ReplayBuffer, | ||
new_unroller_ips_and_ports: mp.Queue, | ||
new_unroller_ips_and_ports: 'Manager.Queue', | ||
worker_id: int) -> None: | ||
"""A worker function for consistently receiving experience data from | ||
unrollers. | ||
|
@@ -230,8 +227,9 @@ def receive_experience_data(replay_buffer: ReplayBuffer, | |
socket = None | ||
# Listen for experience data forever | ||
while True: | ||
while not new_unroller_ips_and_ports.empty(): | ||
unroller_ip, unroller_port = new_unroller_ips_and_ports.get() | ||
try: | ||
unroller_ip, unroller_port = new_unroller_ips_and_ports.get_nowait( | ||
) | ||
# A new unroller has connected to the trainer | ||
if socket is None: | ||
socket, _ = create_zmq_socket(zmq.DEALER, unroller_ip, | ||
|
@@ -241,21 +239,26 @@ def receive_experience_data(replay_buffer: ReplayBuffer, | |
addr = 'tcp://' + ':'.join([unroller_ip, str(unroller_port)]) | ||
# Connect to an additional ROUTER | ||
socket.connect(addr) | ||
except queue.Empty: | ||
pass | ||
|
||
if socket is not None: | ||
# Receive data from any router | ||
unroller_id, message = socket.recv_multipart() | ||
if message == UnrollerMessage.EXP_SEG_END.encode(): | ||
|
||
buffer = io.BytesIO(message) | ||
exp_params = torch.load(buffer, map_location='cpu') | ||
# Use a temp buffer to store the received exps | ||
if unroller_id not in unroller_exps_buffer: | ||
unroller_exps_buffer[unroller_id] = [] | ||
unroller_exps_buffer[unroller_id].append(exp_params) | ||
|
||
if int(exp_params.step_type) == StepType.LAST: | ||
# Add the temp exp buffer to the replay buffer | ||
for exp_params in unroller_exps_buffer[unroller_id]: | ||
for i, exp_params in enumerate( | ||
unroller_exps_buffer[unroller_id]): | ||
replay_buffer.add_batch(exp_params, exp_params.env_id) | ||
unroller_exps_buffer[unroller_id] = [] | ||
else: | ||
buffer = io.BytesIO(message) | ||
exp_params = torch.load(buffer, map_location='cpu') | ||
# Use a temp buffer to store the received exps | ||
if unroller_id not in unroller_exps_buffer: | ||
unroller_exps_buffer[unroller_id] = [] | ||
unroller_exps_buffer[unroller_id].append(exp_params) | ||
else: | ||
time.sleep(0.1) | ||
|
||
|
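The receiving loop above pairs a DEALER socket on the trainer worker with a ROUTER socket on each unroller: the unroller addresses a specific worker by its identity and includes its own identity so the trainer can keep a per-unroller temp buffer. Below is a minimal standalone sketch of that addressing pattern in plain pyzmq; the endpoint, port, and identities are made up for illustration, and the real sockets are created by ALF's ``create_zmq_socket`` helper.

```python
import time
import zmq

ctx = zmq.Context()

# Unroller side: a ROUTER socket that addresses trainer workers by identity.
router = ctx.socket(zmq.ROUTER)
router.bind('tcp://*:5555')                    # port chosen only for this sketch

# Trainer side: a DEALER socket whose identity matches the 'worker-{i}' naming above.
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b'worker-0')
dealer.connect('tcp://127.0.0.1:5555')
time.sleep(0.2)                                # let the handshake register the identity

# ROUTER strips the leading destination frame and delivers the remaining frames.
router.send_multipart([b'worker-0', b'unroller-10.0.0.1-7000', b'payload'])

# Matches ``unroller_id, message = socket.recv_multipart()`` in the loop above.
unroller_id, message = dealer.recv_multipart()
print(unroller_id, message)                    # b'unroller-10.0.0.1-7000' b'payload'
```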
@@ -287,8 +290,7 @@ def pull_params_from_trainer(memory_name: str, unroller_id: str, | |
|
||
|
||
@alf.configurable(whitelist=[ | ||
'max_utd_ratio', 'push_params_every_n_grad_updates', 'checkpoint', 'name', | ||
'optimizer' | ||
'max_utd_ratio', 'push_params_every_n_grad_updates', 'name', 'optimizer' | ||
]) | ||
class DistributedTrainer(DistributedOffPolicyAlgorithm): | ||
def __init__(self, | ||
|
@@ -299,7 +301,6 @@ def __init__(self, | |
env: AlfEnvironment = None, | ||
config: TrainerConfig = None, | ||
optimizer: alf.optimizers.Optimizer = None, | ||
checkpoint: str = None, | ||
debug_summaries: bool = False, | ||
name: str = "DistributedTrainer", | ||
**kwargs): | ||
|
@@ -329,7 +330,6 @@ def __init__(self, | |
env=env, | ||
config=config, | ||
optimizer=optimizer, | ||
checkpoint=checkpoint, | ||
debug_summaries=debug_summaries, | ||
name=name, | ||
**kwargs) | ||
|
@@ -358,6 +358,16 @@ def __init__(self, | |
# may be incremented every mini-batch | ||
self._num_train_iters = 0 | ||
|
||
# respect core_alg's replay buffer setting | ||
self._num_earliest_frames_ignored = self._core_alg._num_earliest_frames_ignored | ||
|
||
# We always test tensor sharing among processes, because | ||
# we rely on undocumented features of PyTorch: | ||
# 1. tensors will automatically be moved to shared memory, even without | ||
# ``Module.share_memory()`` or ``Tensor.share_memory_()`` being called. | ||
# 2. only a 'spawned' subprocess is reliable for tensor sharing. | ||
_test_tensor_sharing() | ||
|
||
def _observe_for_replay(self, exp: Experience): | ||
raise RuntimeError( | ||
'observe_for_replay should not be called for trainer') | ||
|
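The comment above relies on ``torch.multiprocessing`` moving tensor storage into shared memory when a tensor is sent to another process, so a spawned child and its parent see the same data. Here is a small sanity check in the spirit of what ``_test_tensor_sharing`` presumably does; it is an illustrative sketch, not ALF's actual test.

```python
import torch
import torch.multiprocessing as mp


def _child(t: torch.Tensor):
    # Write in place; the storage was moved to shared memory when the tensor
    # was pickled to the spawned process, so the parent sees this write.
    t.fill_(1.0)


if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    t = torch.zeros(4)
    p = mp.Process(target=_child, args=(t, ))
    p.start()
    p.join()
    assert torch.all(t == 1.0), "tensor was not shared with the spawned child"
```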
@@ -409,7 +419,8 @@ def _send_params_to_unroller(self, | |
return False | ||
|
||
def _create_unroller_registration_thread(self): | ||
self._new_unroller_ips_and_ports = mp.Queue() | ||
manager = Manager() | ||
self._new_unroller_ips_and_ports = manager.Queue() | ||
self._unrollers_to_update_params = set() | ||
registered_unrollers = set() | ||
|
||
|
@@ -476,15 +487,12 @@ def _create_data_receiver_subprocess(self): | |
exp = alf.utils.common.prune_exp_replay_state( | ||
exp, self._use_rollout_state, self.rollout_state_spec, | ||
self.train_state_spec) | ||
alf.config('ReplayBuffer', allow_multiprocess=True) | ||
self._set_replay_buffer(exp) | ||
|
||
# In the case of DDP, each subprocess is spawned. By default, if we create | ||
# a new subprocess, the default start method inherited is spawn. In this case, | ||
# we need to explicitly set the start method to fork, so that the daemon | ||
# subprocess can share torch modules. | ||
mp.set_start_method('fork', force=True) | ||
mp.set_start_method('spawn', force=True) | ||
# start the data receiver subprocess | ||
# Need to create the subprocess with 'spawn' so that we can pass a Module | ||
# object to subprocess with tensors in shared memory. | ||
process = mp.Process( | ||
target=receive_experience_data, | ||
args=(self._replay_buffer, self._new_unroller_ips_and_ports, | ||
|
@@ -549,14 +557,14 @@ def _train_iter_off_policy(self): | |
return steps | ||
|
||
|
||
@alf.configurable(whitelist=['deploy_mode', 'checkpoint', 'name', 'optimizer']) | ||
@alf.configurable(whitelist=['episode_length', 'name', 'optimizer']) | ||
class DistributedUnroller(DistributedOffPolicyAlgorithm): | ||
def __init__(self, | ||
core_alg_ctor: Callable, | ||
*args, | ||
episode_length: int = 200, | ||
env: AlfEnvironment = None, | ||
config: TrainerConfig = None, | ||
checkpoint: str = None, | ||
debug_summaries: bool = False, | ||
name: str = "DistributedUnroller", | ||
**kwargs): | ||
|
@@ -565,7 +573,14 @@ def __init__(self, | |
core_alg_ctor: creates the algorithm to be wrapped by this class. | ||
This algorithm's ``predict_step()`` and ``rollout_step()`` will | ||
be used for evaluation and rollout. | ||
checkpoint: this in-alg ckpt will be ignored if ``deploy_mode==False``. | ||
episode_length: the maximum number of experiences sent to one training
worker before switching to the next one. If this arg <= 0, the unroller
will wait for ``StepType.LAST`` from the env before switching. It is the
user's responsibility to make sure that the env returns ``StepType.LAST``.
Otherwise, after every ``episode_length`` experiences, the unroller will set
the step type of the last exp to an artificial ``StepType.LAST`` and switch.
For training safety, it is recommended to always set this value to a
positive number.
*args: additional args to pass to ``core_alg_ctor``. | ||
**kwargs: additional kwargs to pass to ``core_alg_ctor``. | ||
""" | ||
|
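Because ``episode_length`` is whitelisted in the ``alf.configurable`` decorator above, it can be set from a conf file in the same style as the ``alf.config('ReplayBuffer', allow_multiprocess=True)`` call elsewhere in this file. The value below simply restates the default and is only an example.

```python
import alf

# Force an artificial LAST step (and a switch to the next trainer worker)
# every 200 experiences collected by the unroller.
alf.config('DistributedUnroller', episode_length=200)
```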
@@ -579,17 +594,20 @@ def __init__(self, | |
_unroller_port_offset, 2 * _unroller_port_offset)), | ||
env=env, | ||
config=config, | ||
checkpoint=checkpoint, | ||
debug_summaries=debug_summaries, | ||
name=name, | ||
**kwargs) | ||
|
||
self._episode_length = episode_length | ||
self._num_exps = 0 | ||
self._is_first_step = True | ||
|
||
ip = get_local_ip() | ||
self._id = f"unroller-{ip}-{self._port}" | ||
|
||
# For sending experience data | ||
self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*', self._port, | ||
self._id) | ||
# We will create it in a lazy way | ||
self._exp_socket = None | ||
|
||
# Record the current worker the data is being sent to | ||
# To maintain load balance, we want to cycle through the workers | ||
|
@@ -662,6 +680,35 @@ def observe_for_replay(self, exp: Experience): | |
Every time we make sure a full episode is sent to the same DDP rank, if | ||
multi-gpu training is enabled on the trainer. | ||
""" | ||
# Get the current worker id to send the exp to | ||
worker_id = f'worker-{self._current_worker}' | ||
self._num_exps += 1 | ||
episode_end = ((self._episode_length <= 0 and bool(exp.is_last())) | ||
or (self._num_exps % self._episode_length == 0)) | ||
|
||
if self._is_first_step:
# When the unroller has a positive ``episode_length``, we need to correctly
# set the first step type.
if not episode_end:
# In rare cases, the first step is also the last step; we don't
# overwrite LAST with FIRST.
exp = alf.nest.set_field( | ||
exp, 'time_step.step_type', | ||
torch.tensor([StepType.FIRST], dtype=torch.int32)) | ||
self._is_first_step = False | ||
|
||
if episode_end: | ||
# One episode finishes; move to the next worker | ||
# We need to make sure a whole episode is always sent to the same | ||
# worker so that the temporal information is preserved. | ||
exp = alf.nest.set_field(
exp, 'time_step.step_type',
torch.tensor([StepType.LAST], dtype=torch.int32))

Review comment: In the case of a single trainer worker, we don't need to change the step type to LAST.

Reply: If there are multiple unrollers, we still need to set LAST. But it's not straightforward for an unroller to know if there is any other unroller, unless via the trainer. So for simplicity, here we always set LAST.
# Ask the trainer to dump to the replay buffer | ||
self._is_first_step = True | ||
self._current_worker = ( | ||
self._current_worker + 1) % self._num_trainer_workers | ||
|
||
# First prune exp's replay state to save communication overhead | ||
exp = alf.utils.common.prune_exp_replay_state( | ||
exp, self._use_rollout_state, self.rollout_state_spec, | ||
|
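To make the bookkeeping above concrete, here is a tiny standalone trace of the counter logic for a positive ``episode_length`` (3 is chosen only to keep the trace short); it mimics the ``_num_exps``/``_is_first_step`` updates in this diff without any env or ZMQ.

```python
episode_length = 3
num_exps = 0
is_first_step = True

for step in range(1, 8):
    num_exps += 1
    episode_end = (num_exps % episode_length == 0)
    mark = []
    if is_first_step:
        if not episode_end:
            mark.append('FIRST')
        is_first_step = False
    if episode_end:
        # Artificial LAST: the trainer flushes its temp buffer and the
        # unroller switches to the next worker.
        mark.append('LAST')
        is_first_step = True
    print(step, mark or ['MID'])

# Steps 1, 4, 7 are marked FIRST and steps 3, 6 are marked LAST, so each
# trainer worker receives a self-contained 3-step segment.
```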
@@ -672,26 +719,21 @@ def observe_for_replay(self, exp: Experience): | |
buffer = io.BytesIO() | ||
torch.save(exp_params, buffer) | ||
|
||
worker_id = f'worker-{self._current_worker}' | ||
if self._exp_socket is None: | ||
self._exp_socket, _ = create_zmq_socket(zmq.ROUTER, '*', | ||
self._port, self._id) | ||
|
||
Review comment: Should send only for LAST step or episode length reached.

Reply: Right now, we always send on a per-exp basis, instead of waiting for a long traj. The trainer is responsible for maintaining the traj integrity. The reason is latency: sending a very long traj might take a long time (especially with images), blocking the unroller.

try:
self._exp_socket.send_multipart([ | ||
worker_id.encode(), self._exp_socket.identity, | ||
buffer.getvalue() | ||
]) | ||
except zmq.error.ZMQError: # trainer is down | ||
pass | ||
|
||
if bool(exp.is_last()): | ||
# One episode finishes; move to the next worker | ||
# We need to make sure a whole episode is always sent to the same | ||
# worker so that the temporal information is preserved in its replay | ||
# buffer. | ||
self._exp_socket.send_multipart([ | ||
worker_id.encode(), self._exp_socket.identity, | ||
UnrollerMessage.EXP_SEG_END.encode() | ||
]) | ||
self._current_worker = ( | ||
self._current_worker + 1) % self._num_trainer_workers | ||
except zmq.error.ZMQError: | ||
# Trainer is down. | ||
# We might want to keep running the unroller but restart a trainer later. | ||
logging.warning( | ||
f"Trainer {worker_id} is not reachable. Skip sending exp to it." | ||
) | ||
|
||
def _check_paramss_update(self) -> bool: | ||
"""Returns True if params have been updated. | ||
|
Review comment: If the batch size of the replay buffer is 1, env_id has to be 0 at the next line.

Reply: This is true for the current assumption. But since exp_params always contains env_id, we can just use it. Do you mean we should assert it's equal to 0?

Review comment: We can just set it to 0 here?

Reply: Updated.