save videos during training #597
base: master
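This PR threads an optional `VideoWrapper` through the training scripts so rollout videos are recorded alongside checkpoints: `train_rl` writes to `{log_dir}/videos/`, while the adversarial and preference-comparison scripts write to `{log_dir}/checkpoints/videos/`. A sketch of how the feature is exercised, mirroring the smoke tests at the bottom of this diff (the log-root path is illustrative):

    from imitation.scripts import train_rl

    # Run the cartpole named config; as in the smoke tests below, a
    # VideoWrapper is attached to the first sub-environment when the
    # relevant save interval is positive.
    run = train_rl.train_rl_ex.run(
        named_configs=["cartpole"],
        config_updates=dict(common=dict(log_root="/tmp/train_rl_videos")),
    )
    assert run.status == "COMPLETED"
    print(run.config["common"]["log_dir"])  # videos land under <log_dir>/videos/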
File: src/imitation/scripts/common/common.py

@@ -177,4 +177,4 @@ def make_venv(
     try:
         yield venv
     finally:
-        venv.close()
+        venv.close()
File: src/imitation/scripts/common/train.py

@@ -10,6 +10,8 @@
 from imitation.data import rollout
 from imitation.policies import base
 from imitation.scripts.common import common
+from imitation.util import video_wrapper

Comment: unused import?
Reply: Fixed now!

 train_ingredient = sacred.Ingredient("train", ingredients=[common.common_ingredient])
 logger = logging.getLogger(__name__)
File: src/imitation/scripts/train_adversarial.py

@@ -9,6 +9,7 @@
 import torch as th
 from sacred.observers import FileStorageObserver

+import imitation.util.video_wrapper as video_wrapper
 from imitation.algorithms.adversarial import airl as airl_algo
 from imitation.algorithms.adversarial import common
 from imitation.algorithms.adversarial import gail as gail_algo

@@ -111,9 +112,18 @@ def train_adversarial(
     sacred.commands.print_config(_run)

     custom_logger, log_dir = common_config.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"

(AdamGleave marked this conversation as resolved.)

+    video_dir = checkpoint_dir / "videos"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    video_dir.mkdir(parents=True, exist_ok=True)

     expert_trajs = demonstrations.get_expert_trajectories()

-    with common_config.make_venv() as venv:
+    post_wrappers = None
+    if checkpoint_interval > 0:
+        post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)]

Comment: Similar to with preference comparisons, unfortunately
Reply: Should be now fixed with a separate video_save_interval parameter
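For context on what `post_wrappers` does here: each entry is a callable `(env, i) -> env`, and `video_wrapper_factory` (defined later in this diff) returns one that only wraps index 0. Roughly how the wrappers get applied (a simplified sketch, assuming `make_venv` forwards `post_wrappers` to imitation's `make_vec_env`-style helper, which calls each wrapper once per sub-environment):

    from typing import Callable, List, Optional
    import gym

    def apply_post_wrappers(
        env: gym.Env,
        idx: int,
        post_wrappers: Optional[List[Callable[[gym.Env, int], gym.Env]]],
    ) -> gym.Env:
        # Each sub-environment is passed through every post-wrapper in order;
        # video_wrapper_factory's f(env, i) returns env unchanged for i != 0.
        for wrapper in post_wrappers or []:
            env = wrapper(env, idx)
        return env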
+
+    with common_config.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
File: src/imitation/scripts/train_preference_comparisons.py

@@ -12,6 +12,8 @@
 from sacred.observers import FileStorageObserver
 from stable_baselines3.common import type_aliases

+import gym
+import imitation.util.video_wrapper as video_wrapper
 from imitation.algorithms import preference_comparisons
 from imitation.data import types
 from imitation.policies import serialize

@@ -21,6 +23,8 @@
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+import imitation.util.video_wrapper as video_wrapper

Comment: duplicate import?
Reply: Good catch -- fixed

 def save_model(

@@ -149,14 +153,24 @@ def train_preference_comparisons(
         ValueError: Inconsistency between config and deserialized policy normalization.
     """
     custom_logger, log_dir = common.setup_logging()
+    checkpoint_dir = log_dir / "checkpoints"

(AdamGleave marked this conversation as resolved.)

+    video_dir = checkpoint_dir / "videos"
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    video_dir.mkdir(parents=True, exist_ok=True)

     rng = common.make_rng()

-    with common.make_venv() as venv:
+    post_wrappers = None

Comment: This code (and the logging setup above) is basically identical to that of […]
Reply: Good idea -- create a function in common to do this now

+    if checkpoint_interval > 0:

Comment: […] I'm maybe weakly leaning towards 3) because that'll make it easier to decouple this from each individual script and just factor it into […]
Reply: I definitely agree with point 3 there -- should be now fixed with a separate video_save_interval parameter

+        post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)]

Comment: This always enables video generation when checkpointing is enabled. I think video generation should be toggleable by another flag. Generating videos often requires installing extra dependencies and maybe running an X server in the background, as well as having a performance overhead. So it's not something everyone wants (I'd probably lean towards having it off by default actually, but can see an argument either way).
Reply: Should be fixed now
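The fix both replies point to decouples video recording from checkpointing via its own config value. A sketch of that shape (hypothetical: `video_save_interval` is the name the replies mention, not code present in this diff):

    # Hypothetical sacred config value, default-off as the reviewer prefers.
    video_save_interval = 0  # >0 enables recording roughly every N steps

    post_wrappers = None
    if video_save_interval > 0:
        post_wrappers = [
            video_wrapper.video_wrapper_factory(video_dir, video_save_interval),
        ]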
+
+    with common.make_venv(post_wrappers=post_wrappers) as venv:
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
             update_stats=False,
         )

         if agent_path is None:
             agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
         else:

@@ -287,6 +301,5 @@ def main_console():
     train_preference_comparisons_ex.observers.append(observer)
     train_preference_comparisons_ex.run_commandline()

-
 if __name__ == "__main__":  # pragma: no cover
     main_console()
File: src/imitation/scripts/train_rl.py

@@ -1,5 +1,4 @@
 """Uses RL to train a policy from scratch, saving rollouts and policy.
-
 This can be used:
 1. To train a policy on a ground-truth reward function, as a source of
    synthetic "expert" demonstrations to train IRL or imitation learning

@@ -17,6 +16,7 @@
 from stable_baselines3.common import callbacks
 from stable_baselines3.common.vec_env import VecNormalize

+import imitation.util.video_wrapper as video_wrapper
 from imitation.data import rollout, types, wrappers
 from imitation.policies import serialize
 from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

@@ -41,15 +41,13 @@ def train_rl(
     policy_save_final: bool,
     agent_path: Optional[str],
 ) -> Mapping[str, float]:
-    """Trains an expert policy from scratch and saves the rollouts and policy.
-
+    """Trains an expert policy from scratch and saves the rollouts and policy.
     Checkpoints:
       At applicable training steps `step` (where step is either an integer or
       "final"):
-
         - Policies are saved to `{log_dir}/policies/{step}/`.
         - Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.
-

Comment: I think we want to keep this whitespace between sections to be compliant with Google docstring style, which we've adopted in this project: https://google.github.io/styleguide/pyguide.html#doc-function-raises. Removing whitespace before the list above seems fine though.
Reply: Fixed now

     Args:
         total_timesteps: Number of training timesteps in `model.learn()`.
         normalize_reward: Applies normalization and clipping to the reward function by

@@ -82,18 +80,25 @@ def train_rl(
         policy_save_final: If True, then save the policy right after training is
             finished.
         agent_path: Path to load warm-started agent.
-
     Returns:
         The return value of `rollout_stats()` using the final policy.
     """
     rng = common.make_rng()
     custom_logger, log_dir = common.setup_logging()
     rollout_dir = log_dir / "rollouts"
     policy_dir = log_dir / "policies"
+    video_dir = log_dir / "videos"
     rollout_dir.mkdir(parents=True, exist_ok=True)
     policy_dir.mkdir(parents=True, exist_ok=True)
+    video_dir.mkdir(parents=True, exist_ok=True)

     post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]

+    if policy_save_interval > 0:
+        post_wrappers.append(
+            video_wrapper.video_wrapper_factory(video_dir, policy_save_interval)
+        )

     with common.make_venv(post_wrappers=post_wrappers) as venv:
         callback_objs = []
         if reward_type is not None:

@@ -164,4 +169,4 @@ def main_console():
 if __name__ == "__main__":  # pragma: no cover
-    main_console()
+    main_console()
File: src/imitation/util/video_wrapper.py

@@ -1,8 +1,7 @@
 """Wrapper to record rendered video frames from an environment."""

 import pathlib
-from typing import Optional
-
+from typing import Optional, Callable
 import gym
 from gym.wrappers.monitoring import video_recorder

@@ -14,12 +13,16 @@ class VideoWrapper(gym.Wrapper):
     video_recorder: Optional[video_recorder.VideoRecorder]
     single_video: bool
     directory: pathlib.Path
+    cadence: int
+    should_record: bool
+    step_count: int

     def __init__(
         self,
         env: gym.Env,
         directory: pathlib.Path,
         single_video: bool = True,
+        cadence: int = 1,
     ):
         """Builds a VideoWrapper.

@@ -31,14 +34,22 @@ def __init__(
                 Usually a single video file is what is desired. However, if one is
                 searching for an interesting episode (perhaps by looking at the
                 metadata), then saving to different files can be useful.
+            cadence: the video wrapper will save a video of the next episode that
+                begins after every Nth step. So if cadence=100 and each episode has
+                30 steps, it will record the 4th episode (first to start after
+                step_count=100) and then the 7th episode (first to start after
+                step_count=200).
         """
         super().__init__(env)
         self.episode_id = 0
         self.video_recorder = None
         self.single_video = single_video
+        self.cadence = cadence

         self.directory = directory
         self.directory.mkdir(parents=True, exist_ok=True)
+        self.should_record = False
+        self.step_count = 0

     def _reset_video_recorder(self) -> None:
         """Creates a video recorder if one does not already exist.

@@ -53,13 +64,14 @@ def _reset_video_recorder(self) -> None:
             self.video_recorder.close()
             self.video_recorder = None

-        if self.video_recorder is None:
+        if self.video_recorder is None and (self.should_record or self.step_count % self.cadence == 0):
             # No video recorder -- start a new one.
             self.video_recorder = video_recorder.VideoRecorder(
                 env=self.env,
                 base_path=str(self.directory / f"video.{self.episode_id:06}"),
                 metadata={"episode_id": self.episode_id},
             )
+            self.should_record = False

     def reset(self):
         self._reset_video_recorder()

@@ -68,11 +80,30 @@ def reset(self):
     def step(self, action):
         res = self.env.step(action)
-        self.video_recorder.capture_frame()
+        self.step_count += 1
+        if self.step_count % self.cadence == 0:
+            self.should_record == 0

Comment: This line doesn't do anything (equality check) -- was this meant to be an assignment? Also […]
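What the reviewer is pointing at: `==` compares and discards the result, and the sentinel is a bool, not an int. The intended logic is presumably an assignment (a sketch of the fix, not code from this diff; the truncated "Also" plausibly refers to the `!= None` comparison on the next line, which would idiomatically be `is not None`):

    # Presumed intent: arm the recorder once the step counter crosses a
    # cadence boundary, so the next reset() starts a new recording.
    self.step_count += 1
    if self.step_count % self.cadence == 0:
        self.should_record = True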
+        if self.video_recorder != None:
+            self.video_recorder.capture_frame()
         return res

     def close(self) -> None:
         if self.video_recorder is not None:
             self.video_recorder.close()
             self.video_recorder = None
         super().close()

+def video_wrapper_factory(video_dir: pathlib.Path, cadence: int, **kwargs) -> Callable:
+    def f(env: gym.Env, i: int) -> VideoWrapper:
+        """Returns a wrapper around a gym environment that records video if and only if i is 0.
+
+        Args:
+            env: the environment to be wrapped.
+            i: the index of the environment. This is to make the video wrapper compatible
+                with vectorized environments. Only environments with i=0 actually attach
+                the VideoWrapper.
+        """
+        return VideoWrapper(env, directory=video_dir, cadence=cadence, **kwargs) if i == 0 else env
+    return f
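A quick usage sketch of the wrapper on its own, outside the scripts (assumes a renderable gym environment and ffmpeg available for gym's VideoRecorder):

    import pathlib

    import gym

    from imitation.util import video_wrapper

    env = gym.make("CartPole-v1")
    # With cadence=100, a new recording is armed whenever step_count is a
    # multiple of 100 at reset time (including the very first reset).
    env = video_wrapper.VideoWrapper(
        env,
        directory=pathlib.Path("videos"),
        single_video=False,
        cadence=100,
    )
    obs = env.reset()
    for _ in range(500):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
    env.close()  # flushes the recorder; files land in videos/video.<episode>.mp4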
File: tests/test_scripts.py

@@ -1,5 +1,4 @@
 """Smoke tests for CLI programs in `imitation.scripts.*`.
-
 Every test in this file should use `parallel=False` to turn off multiprocessing because
 codecov might interact poorly with multiprocessing. The 'fast' named_config for each
 experiment implicitly sets parallel=False.

@@ -73,10 +72,9 @@
 @pytest.fixture(autouse=True)
 def sacred_capture_use_sys():
-    """Set Sacred capture mode to "sys" because default "fd" option leads to error.
-
+    """Set Sacred capture mode to "sys" because default "fd" option leads to error.

Comment: why extra whitespace?
Reply: The CircleCI linter was putting a warning on this -- I'll remove all the changes that weren't to functions I created/was working on
Comment: That's odd, most of our functions do not have whitespace before the docstring, so I am not sure why this would cause a linter error. Is it this? […] I think that's complaining there's not a newline following this line, i.e. it's expecting docstrings in the format:
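The format example did not survive the page capture, but the shape being described (summary line, blank line, then body -- pydocstyle's D205, presumably) is:

    def sacred_capture_use_sys():
        """Set Sacred capture mode to "sys" because default "fd" option leads to error.

        See https://github.com/IDSIA/sacred/issues/289.

        Yields:
            None after setting capture mode; restores it after yield.
        """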
     See https://github.com/IDSIA/sacred/issues/289.
-

Comment: do we want to remove this?
Reply: Fixed

     Yields:
         None after setting capture mode; restores it after yield.
     """

@@ -602,9 +600,7 @@ def test_train_adversarial_algorithm_value_error(tmpdir):
 def test_transfer_learning(tmpdir: str) -> None:
-    """Transfer learning smoke test.
-
+    """Transfer learning smoke test.

Comment: do we want to remove this?
Reply: Fixed

     Saves a dummy AIRL test reward, then loads it for transfer learning.
-

Comment: or this?

     Args:
         tmpdir: Temporary directory to save results to.
     """

@@ -649,10 +645,9 @@ def test_preference_comparisons_transfer_learning(
     tmpdir: str,
     named_configs_dict: Mapping[str, List[str]],
 ) -> None:
-    """Transfer learning smoke test.
-
+    """Transfer learning smoke test.

(AdamGleave marked this conversation as resolved.)

     Saves a preference comparisons ensemble reward, then loads it for transfer learning.
-

Comment: do we want to remove this?

     Args:
         tmpdir: Temporary directory to save results to.
         named_configs_dict: Named configs for preference_comparisons and rl.

@@ -953,3 +948,57 @@ def test_convert_trajs(tmpdir: str):
     assert len(from_pkl) == len(from_npz)
     for t_pkl, t_npz in zip(from_pkl, from_npz):
         assert t_pkl == t_npz

+# Change the following if the file structure of checkpoints changes.
+VIDEO_FILE_PATH = "video.{:06}.mp4".format(0)
+VIDEO_PATH_DICT = dict(
+    rl=lambda d: d / "videos",
+    adversarial=lambda d: d / "checkpoints" / "videos",
+    pc=lambda d: d / "checkpoints" / "videos",
+)

+def _check_video_exists(log_dir, algo):
+    video_dir = VIDEO_PATH_DICT[algo](log_dir)
+    assert os.path.exists(video_dir)

Comment: If […]
Reply: Just updated!

+    assert VIDEO_FILE_PATH in os.listdir(video_dir)

Comment: (I'd guess there's a pathlib version of this too but not sure.)
Reply: Should be updated now
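The pathlib equivalent the reviewer guesses at does exist; the checks could read (a sketch, assuming `log_dir` is a `pathlib.Path` as elsewhere in these scripts):

    def _check_video_exists(log_dir, algo):
        video_dir = VIDEO_PATH_DICT[algo](log_dir)
        assert video_dir.exists()  # pathlib replacement for os.path.exists
        assert (video_dir / VIDEO_FILE_PATH).exists()  # replaces the os.listdir membership test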
+def test_train_rl_video_saving(tmpdir):

Comment: There's a lot of duplication here with […]

+    """Smoke test for imitation.scripts.train_rl."""
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+    )
+    run = train_rl.train_rl_ex.run(
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"],
+        config_updates=config_updates,
+    )
+
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "rl")

+def test_train_adversarial_video_saving(tmpdir):
+    """Smoke test for imitation.scripts.train_adversarial."""
+    named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"]
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
+        checkpoint_interval=1,
+    )
+    run = train_adversarial.train_adversarial_ex.run(
+        command_name="gail",
+        named_configs=named_configs,
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "adversarial")

+def test_train_preference_comparisons_video_saving(tmpdir):
+    config_updates = dict(
+        common=dict(log_root=tmpdir),
+        checkpoint_interval=1,
+    )
+    run = train_preference_comparisons.train_preference_comparisons_ex.run(
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"],
+        config_updates=config_updates,
+    )
+    assert run.status == "COMPLETED"
+    _check_video_exists(run.config["common"]["log_dir"], "pc")
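On the duplication comment above: the three tests differ only in experiment object, named configs, and config updates, so they could collapse into one parametrized test, roughly (a sketch, not part of this diff; it reuses the module's existing imports and constants):

    import pytest

    @pytest.mark.parametrize(
        "algo, ex, command_name, named_configs, extra_updates",
        [
            ("rl", train_rl.train_rl_ex, None,
             ["cartpole"] + ALGO_FAST_CONFIGS["rl"], dict()),
            ("adversarial", train_adversarial.train_adversarial_ex, "gail",
             ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"],
             dict(demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
                  checkpoint_interval=1)),
            ("pc", train_preference_comparisons.train_preference_comparisons_ex, None,
             ["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"],
             dict(checkpoint_interval=1)),
        ],
    )
    def test_video_saving(tmpdir, algo, ex, command_name, named_configs, extra_updates):
        # command_name=None runs the experiment's main command in sacred.
        config_updates = dict(common=dict(log_root=tmpdir), **extra_updates)
        run = ex.run(
            command_name=command_name,
            named_configs=named_configs,
            config_updates=config_updates,
        )
        assert run.status == "COMPLETED"
        _check_video_exists(run.config["common"]["log_dir"], algo)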
File: tests/test_logger.py

@@ -113,6 +113,11 @@ def test_wandb_output_format():
         {"_step": 0, "foo": 42, "fizz": 12},
         {"_step": 3, "fizz": 21},
     ]

+    with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):

Comment: Testing invalid input error handling is good, but it's a bit odd we're not also testing that it does the right thing with a valid input?
Reply: Testing the basic saving features already exists in a previous test (since manual video saving is already supported) -- I just added this test since it was in the original PR and figured it couldn't hurt.
+        log_obj.record("video", 42)
+        log_obj.dump(step=4)
+
     log_obj.close()
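For completeness, the valid-input counterpart asked about in the thread above would look something like this (hypothetical: assumes the logger forwards values recorded under a video key to wandb.Video, which accepts a file path or numpy-like data):

    # Hypothetical happy-path check; "videos/video.000000.mp4" is a stand-in
    # for a file produced by VideoWrapper, not a fixture in this diff.
    log_obj.record("video", "videos/video.000000.mp4")
    log_obj.dump(step=5)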
Comment: Was this change (removing newline) intentional?
Reply: Nope -- fixing now