save videos during training #597

Open
wants to merge 13 commits into base: master
2 changes: 1 addition & 1 deletion src/imitation/scripts/common/common.py
@@ -177,4 +177,4 @@ def make_venv(
try:
yield venv
finally:
venv.close()
venv.close()
Member:
Was this change (removing newline) intentional?

Contributor Author:
Nope -- fixing now

1 change: 1 addition & 0 deletions src/imitation/scripts/common/rl.py
@@ -154,6 +154,7 @@ def make_rl_algo(
)
else:
raise TypeError(f"Unsupported RL algorithm '{rl_cls}'")

rl_algo = rl_cls(
policy=train["policy_cls"],
# Note(yawen): Copy `policy_kwargs` as SB3 may mutate the config we pass.
2 changes: 2 additions & 0 deletions src/imitation/scripts/common/train.py
@@ -10,6 +10,8 @@
from imitation.data import rollout
from imitation.policies import base
from imitation.scripts.common import common
from imitation.util import video_wrapper
Member:
unused import?

Contributor Author:
Fixed now!



train_ingredient = sacred.Ingredient("train", ingredients=[common.common_ingredient])
logger = logging.getLogger(__name__)
12 changes: 11 additions & 1 deletion src/imitation/scripts/train_adversarial.py
@@ -9,6 +9,7 @@
import torch as th
from sacred.observers import FileStorageObserver

import imitation.util.video_wrapper as video_wrapper
from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
from imitation.algorithms.adversarial import gail as gail_algo
@@ -111,9 +112,18 @@ def train_adversarial(
sacred.commands.print_config(_run)

custom_logger, log_dir = common_config.setup_logging()
checkpoint_dir = log_dir / "checkpoints"
AdamGleave marked this conversation as resolved.
video_dir = checkpoint_dir / "videos"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
video_dir.mkdir(parents=True, exist_ok=True)

expert_trajs = demonstrations.get_expert_trajectories()

with common_config.make_venv() as venv:
post_wrappers = None
if checkpoint_interval > 0:
post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)]
Member:
Similarly to preference comparisons, unfortunately checkpoint_interval is not the same thing as timesteps, so some conversion or other solution is needed.

Contributor Author:
Should be now fixed with a separate video_save_interval parameter


with common_config.make_venv(post_wrappers=post_wrappers) as venv:
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
17 changes: 15 additions & 2 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -12,6 +12,8 @@
from sacred.observers import FileStorageObserver
from stable_baselines3.common import type_aliases

import gym
import imitation.util.video_wrapper as video_wrapper
from imitation.algorithms import preference_comparisons
from imitation.data import types
from imitation.policies import serialize
@@ -21,6 +23,8 @@
from imitation.scripts.config.train_preference_comparisons import (
train_preference_comparisons_ex,
)
import imitation.util.video_wrapper as video_wrapper
Member:
duplicate import?

Contributor Author:
Good catch -- fixed




def save_model(
@@ -149,14 +153,24 @@ def train_preference_comparisons(
ValueError: Inconsistency between config and deserialized policy normalization.
"""
custom_logger, log_dir = common.setup_logging()
checkpoint_dir = log_dir / "checkpoints"
AdamGleave marked this conversation as resolved.
video_dir = checkpoint_dir / "videos"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
video_dir.mkdir(parents=True, exist_ok=True)

rng = common.make_rng()

with common.make_venv() as venv:
post_wrappers = None
Member:
This code (and the logging setup above) is basically identical to that of train_adversarial. This makes me wonder if it should be factored out into the common Sacred ingredient? You could just add it to part of make_venv. You'd probably need to pass in the checkpoint_interval (could default to None), but everything else it needs I think is already captured in the scope.

Contributor Author:
Good idea -- created a function in common to do this now
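
For context, a rough sketch of how that shared setup might look as a helper in the common ingredient (names and placement are assumptions, not the PR's actual code; only `video_wrapper_factory` comes from this diff):

```python
# Hypothetical helper for imitation/scripts/common/common.py -- a sketch only.
import pathlib
from typing import Callable, List, Optional

import gym

from imitation.util import video_wrapper


def make_video_post_wrappers(
    log_dir: pathlib.Path,
    video_save_interval: Optional[int],
) -> List[Callable[[gym.Env, int], gym.Env]]:
    """Returns post-wrappers that record videos, or an empty list if disabled."""
    if not video_save_interval or video_save_interval <= 0:
        return []
    video_dir = log_dir / "videos"
    video_dir.mkdir(parents=True, exist_ok=True)
    return [video_wrapper.video_wrapper_factory(video_dir, video_save_interval)]
```

Each script would then build `post_wrappers` with this helper before calling `make_venv(post_wrappers=post_wrappers)`.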

if checkpoint_interval > 0:
Member:
checkpoint_interval here I think is in terms of number of iterations of the algorithm, which is much smaller than the number of timesteps of RL interaction. Sorry if I misled you earlier in our call here. I think the options here are:

  1. Do the conversion from checkpoint_interval to total timesteps.
  2. Figure out another way to trigger it at each checkpoint.
  3. Don't try to record it at each checkpoint, and just have a separate config flag for video_interval.

I'm maybe weakly leaning towards 3) because that'll make it easier to decouple this from each individual script and just factor it into common, you could also combine that with 2) if you just expose a function that instructs video wrapper to save the next episode.

Contributor Author:
I definitely agree with point 3 there -- should be now fixed with a separate video_save_interval parameter
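
For reference, a minimal sketch of what a separate `video_save_interval` option might look like on the config side (the hook name and default are assumptions; only the parameter name follows the author's description):

```python
# Hypothetical Sacred config addition -- a sketch, not the PR's final code.
from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)


@train_preference_comparisons_ex.config
def video_saving():
    # Environment steps between recorded episodes; 0 disables recording, so the
    # same parameter doubles as the on/off toggle discussed further down.
    video_save_interval = 0


# The guard in the script body would then mirror the diff's structure:
# post_wrappers = None
# if video_save_interval > 0:
#     post_wrappers = [
#         video_wrapper.video_wrapper_factory(video_dir, video_save_interval),
#     ]
```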

post_wrappers = [video_wrapper.video_wrapper_factory(video_dir, checkpoint_interval)]
Member:
This always enables video generation when checkpointing is enabled. I think video generation should be toggleable by another flag. Generating videos often requires installing extra dependencies and maybe running an X server in the background, as well as having a performance overhead. So it's not something everyone wants (I'd probably lean towards having it off by default actually, but can see an argument either way).

Contributor Author:
Should be fixed now


with common.make_venv(post_wrappers=post_wrappers) as venv:
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
update_stats=False,
)

if agent_path is None:
agent = rl_common.make_rl_algo(venv, relabel_reward_fn=relabel_reward_fn)
else:
@@ -287,6 +301,5 @@ def main_console():
train_preference_comparisons_ex.observers.append(observer)
train_preference_comparisons_ex.run_commandline()


if __name__ == "__main__": # pragma: no cover
main_console()
17 changes: 11 additions & 6 deletions src/imitation/scripts/train_rl.py
@@ -1,5 +1,4 @@
"""Uses RL to train a policy from scratch, saving rollouts and policy.

This can be used:
1. To train a policy on a ground-truth reward function, as a source of
synthetic "expert" demonstrations to train IRL or imitation learning
@@ -17,6 +16,7 @@
from stable_baselines3.common import callbacks
from stable_baselines3.common.vec_env import VecNormalize

import imitation.util.video_wrapper as video_wrapper
from imitation.data import rollout, types, wrappers
from imitation.policies import serialize
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper
@@ -41,15 +41,13 @@ def train_rl(
policy_save_final: bool,
agent_path: Optional[str],
) -> Mapping[str, float]:
"""Trains an expert policy from scratch and saves the rollouts and policy.

"""Trains an expert policy from scratch and saves the rollouts and policy.
Checkpoints:
At applicable training steps `step` (where step is either an integer or
"final"):

- Policies are saved to `{log_dir}/policies/{step}/`.
- Rollouts are saved to `{log_dir}/rollouts/{step}.npz`.

Member:
I think we want to keep this whitespace between sections to be compliant with Google docstring style which we've adopted in this project: https://google.github.io/styleguide/pyguide.html#doc-function-raises

Removing whitespace before the list above seems fine though.

Contributor Author:
Fixed now

Args:
total_timesteps: Number of training timesteps in `model.learn()`.
normalize_reward: Applies normalization and clipping to the reward function by
@@ -82,18 +80,25 @@
policy_save_final: If True, then save the policy right after training is
finished.
agent_path: Path to load warm-started agent.

Returns:
The return value of `rollout_stats()` using the final policy.
"""
rng = common.make_rng()
custom_logger, log_dir = common.setup_logging()
rollout_dir = log_dir / "rollouts"
policy_dir = log_dir / "policies"
video_dir = log_dir / "videos"
rollout_dir.mkdir(parents=True, exist_ok=True)
policy_dir.mkdir(parents=True, exist_ok=True)
video_dir.mkdir(parents=True, exist_ok=True)

post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]

if policy_save_interval > 0:
post_wrappers.append(
video_wrapper.video_wrapper_factory(video_dir, policy_save_interval)
)

with common.make_venv(post_wrappers=post_wrappers) as venv:
callback_objs = []
if reward_type is not None:
@@ -164,4 +169,4 @@ def main_console():


if __name__ == "__main__": # pragma: no cover
main_console()
main_console()
5 changes: 4 additions & 1 deletion src/imitation/util/logger.py
@@ -377,7 +377,10 @@ def write(
if excluded is not None and "wandb" in excluded:
continue

self.wandb_module.log({key: value}, step=step)
if key != "video":
self.wandb_module.log({key: value}, step=step)
else:
self.wandb_module.log({"video": self.wandb_module.Video(value)})
self.wandb_module.log({}, commit=True)

def close(self) -> None:
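The effect of the special case above, in a hedged usage sketch (assumes a W&B run and an existing `episode.mp4` on disk; `wandb.Video` also accepts raw frame arrays):

```python
import wandb

run = wandb.init(project="imitation", mode="offline")
# Ordinary scalars keep the usual step-indexed logging path:
run.log({"return_mean": 123.4}, step=10)
# The "video" key is wrapped in wandb.Video, so the file is uploaded and
# rendered as a playable clip in the W&B UI:
run.log({"video": wandb.Video("episode.mp4")})
run.finish()
```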
39 changes: 35 additions & 4 deletions src/imitation/util/video_wrapper.py
@@ -1,8 +1,7 @@
"""Wrapper to record rendered video frames from an environment."""

import pathlib
from typing import Optional

from typing import Optional, Callable
import gym
from gym.wrappers.monitoring import video_recorder

@@ -14,12 +13,16 @@ class VideoWrapper(gym.Wrapper):
video_recorder: Optional[video_recorder.VideoRecorder]
single_video: bool
directory: pathlib.Path
cadence: int
should_record: bool
step_count: int

def __init__(
self,
env: gym.Env,
directory: pathlib.Path,
single_video: bool = True,
cadence: int = 1,
):
"""Builds a VideoWrapper.

@@ -31,14 +34,22 @@ def __init__(
Usually a single video file is what is desired. However, if one is
searching for an interesting episode (perhaps by looking at the
metadata), then saving to different files can be useful.
cadence: the video wrapper will save a video of the next episode that
begins after every Nth step. So if cadence=100 and each episode has
30 steps, it will record the 4th episode(first to start after
Member:
Suggested change
30 steps, it will record the 4th episode(first to start after
30 steps, it will record the 4th episode (first to start after

step_count=100) and then the 7thepisode (first to start after
Member:
Suggested change
step_count=100) and then the 7thepisode (first to start after
step_count=100) and then the 7th episode (first to start after

Contributor Author:
Fixed!

step_count=200).
"""
super().__init__(env)
self.episode_id = 0
self.video_recorder = None
self.single_video = single_video
self.cadence = cadence

self.directory = directory
self.directory.mkdir(parents=True, exist_ok=True)
self.should_record = False
self.step_count = 0

def _reset_video_recorder(self) -> None:
"""Creates a video recorder if one does not already exist.
@@ -53,13 +64,14 @@ def _reset_video_recorder(self) -> None:
self.video_recorder.close()
self.video_recorder = None

if self.video_recorder is None:
if self.video_recorder is None and (self.should_record or self.step_count % self.cadence == 0):
# No video recorder -- start a new one.
self.video_recorder = video_recorder.VideoRecorder(
env=self.env,
base_path=str(self.directory / f"video.{self.episode_id:06}"),
metadata={"episode_id": self.episode_id},
)
self.should_record = False

def reset(self):
self._reset_video_recorder()
@@ -68,11 +80,30 @@ def reset(self):

def step(self, action):
res = self.env.step(action)
self.video_recorder.capture_frame()
self.step_count += 1
if self.step_count % self.cadence == 0:
self.should_record == 0
Member:
This line doesn't do anything (equality check) -- was this meant to be an assignment? Also self.should_record has type bool not int.

if self.video_recorder != None:
self.video_recorder.capture_frame()
return res

def close(self) -> None:
if self.video_recorder is not None:
self.video_recorder.close()
self.video_recorder = None
super().close()


def video_wrapper_factory(video_dir: pathlib.Path, cadence: int, **kwargs) -> Callable:
def f(env: gym.Env, i: int) -> VideoWrapper:
"""
Returns a wrapper around a gym environment that records a video if and only if i is 0.

Args:
env: the environment to be wrapped around
i: the index of the environment. This is to make the video wrapper compatible with
vectorized environments. Only environments with i=0 actually attach the VideoWrapper
"""

return VideoWrapper(env, directory=video_dir, cadence=cadence, **kwargs) if i == 0 else env
return f
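
Reading the `step` diff together with the reviewer's comment on the `==` line, the intended logic is presumably an assignment to the boolean flag; a minimal sketch (hypothetical subclass name, for illustration only, assuming the attributes defined in `__init__` above):

```python
# Sketch of the presumably intended step() -- not the PR's final code.
from imitation.util.video_wrapper import VideoWrapper


class PatchedVideoWrapper(VideoWrapper):
    def step(self, action):
        res = self.env.step(action)
        self.step_count += 1
        if self.step_count % self.cadence == 0:
            # Assignment of a bool, not the `== 0` comparison flagged above:
            # arm recording so that the next episode to reset gets captured.
            self.should_record = True
        if self.video_recorder is not None:
            self.video_recorder.capture_frame()
        return res
```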
63 changes: 56 additions & 7 deletions tests/scripts/test_scripts.py
@@ -1,5 +1,4 @@
"""Smoke tests for CLI programs in `imitation.scripts.*`.

Every test in this file should use `parallel=False` to turn off multiprocessing because
codecov might interact poorly with multiprocessing. The 'fast' named_config for each
experiment implicitly sets parallel=False.
@@ -73,10 +72,9 @@

@pytest.fixture(autouse=True)
def sacred_capture_use_sys():
"""Set Sacred capture mode to "sys" because default "fd" option leads to error.

"""Set Sacred capture mode to "sys" because default "fd" option leads to error.
Member:
why extra whitespace?

Contributor Author:
The CircleCI linter was putting a warning on this -- I'll remove all the changes that weren't to functions I created/was working on

Member:
That's odd, most of our functions do not have whitespace before docstring, I am not sure why this would cause a linter error.

Is it this?

tests/scripts/test_scripts.py:1:1: D205 1 blank line required between summary line and description

I think that's complaining there's not a newline following this line, i.e. it's expecting docstrings in format:

A short one sentence summary.

Optionally, a more elaborate description of what this function does.
Some details only for the astute reader.

Args:
    foo: ...

See https://github.com/IDSIA/sacred/issues/289.

Member:
do we want to remove this?

Contributor Author:
Fixed

Yields:
None after setting capture mode; restores it after yield.
"""
@@ -602,9 +600,7 @@ def test_train_adversarial_algorithm_value_error(tmpdir):

def test_transfer_learning(tmpdir: str) -> None:
"""Transfer learning smoke test.

Member:
do we want to remove this?

Contributor Author:
Fixed

Saves a dummy AIRL test reward, then loads it for transfer learning.

Member:
or this?

Args:
tmpdir: Temporary directory to save results to.
"""
@@ -649,10 +645,9 @@ def test_preference_comparisons_transfer_learning(
tmpdir: str,
named_configs_dict: Mapping[str, List[str]],
) -> None:
"""Transfer learning smoke test.

"""Transfer learning smoke test.
AdamGleave marked this conversation as resolved.
Saves a preference comparisons ensemble reward, then loads it for transfer learning.

Member:
do we want to remove this?

Args:
tmpdir: Temporary directory to save results to.
named_configs_dict: Named configs for preference_comparisons and rl.
@@ -953,3 +948,57 @@ def test_convert_trajs(tmpdir: str):
assert len(from_pkl) == len(from_npz)
for t_pkl, t_npz in zip(from_pkl, from_npz):
assert t_pkl == t_npz

# Change the following if the file structure of checkpoints changed.
VIDEO_FILE_PATH = "video.{:06}.mp4".format(0)
VIDEO_PATH_DICT = dict(
rl=lambda d: d / "videos",
adversarial=lambda d: d / "checkpoints" / "videos",
pc=lambda d: d / "checkpoints" / "videos"
)

def _check_video_exists(log_dir, algo):
video_dir = VIDEO_PATH_DICT[algo](log_dir)
assert os.path.exists(video_dir)
Member:
If video_dir is a pathlib.Path I think just video_dir.exists() works

Contributor Author:
Just updated!

assert VIDEO_FILE_PATH in os.listdir(video_dir)
Member:
(I'd guess there's a pathlib version of this too but not sure.)

Contributor Author:
Should be updated now
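
Combining both suggestions, the pathlib-only version of the helper could read as follows (a sketch; assumes the `VIDEO_PATH_DICT` lambdas return `pathlib.Path` objects):

```python
def _check_video_exists(log_dir, algo):
    video_dir = VIDEO_PATH_DICT[algo](log_dir)
    assert video_dir.is_dir()
    assert (video_dir / VIDEO_FILE_PATH).is_file()
```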


def test_train_rl_video_saving(tmpdir):
Member:
There's a lot of duplication here with test_train_rl_main, perhaps we can combine them somehow? This comment also applies to some extent to test_train_adversarial_* and test_train_preference_comparisons_* below.

"""Smoke test for imitation.scripts.train_rl."""
config_updates = dict(
common=dict(log_root=tmpdir)
)
run = train_rl.train_rl_ex.run(
named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"],
config_updates=config_updates,
)

assert run.status == "COMPLETED"
_check_video_exists(run.config["common"]["log_dir"], "rl")

def test_train_adversarial_video_saving(tmpdir):
"""Smoke test for imitation.scripts.train_adversarial."""
named_configs = ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"]
config_updates = dict(
common=dict(log_root=tmpdir),
demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
checkpoint_interval=1
)
run = train_adversarial.train_adversarial_ex.run(
command_name="gail",
named_configs=named_configs,
config_updates=config_updates,
)
assert run.status == "COMPLETED"
_check_video_exists(run.config["common"]["log_dir"], "adversarial")

def test_train_preference_comparisons_video_saving(tmpdir):
config_updates = dict(
common=dict(log_root=tmpdir),
checkpoint_interval=1
)
run = train_preference_comparisons.train_preference_comparisons_ex.run(
named_configs=["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"],
config_updates=config_updates,
)
assert run.status == "COMPLETED"
_check_video_exists(run.config["common"]["log_dir"], "pc")
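
On the duplication point raised above, one way to consolidate the three smoke tests is a single parametrized test; a sketch that reuses the constants and helpers already defined in this file (the exact parametrization is an assumption, not the PR's code):

```python
import pytest


@pytest.mark.parametrize(
    "algo,experiment,command,named_configs,extra_config",
    [
        ("rl", train_rl.train_rl_ex, None, ["cartpole"] + ALGO_FAST_CONFIGS["rl"], {}),
        (
            "adversarial",
            train_adversarial.train_adversarial_ex,
            "gail",
            ["pendulum"] + ALGO_FAST_CONFIGS["adversarial"],
            dict(
                demonstrations=dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
                checkpoint_interval=1,
            ),
        ),
        (
            "pc",
            train_preference_comparisons.train_preference_comparisons_ex,
            None,
            ["cartpole"] + ALGO_FAST_CONFIGS["preference_comparison"],
            dict(checkpoint_interval=1),
        ),
    ],
)
def test_video_saving(tmpdir, algo, experiment, command, named_configs, extra_config):
    """Smoke test that each training script writes a video to the expected directory."""
    config_updates = dict(common=dict(log_root=tmpdir), **extra_config)
    kwargs = dict(named_configs=named_configs, config_updates=config_updates)
    if command is not None:
        kwargs["command_name"] = command
    run = experiment.run(**kwargs)
    assert run.status == "COMPLETED"
    _check_video_exists(run.config["common"]["log_dir"], algo)
```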
5 changes: 5 additions & 0 deletions tests/util/test_wb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def test_wandb_output_format():
{"_step": 0, "foo": 42, "fizz": 12},
{"_step": 3, "fizz": 21},
]

with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
Member:
Testing invalid input error handling is good but it's a bit odd we're not also testing that it does the right thing with a valid input?

Contributor Author:
Testing the basic saving features already exists in a previous test (since manual video saving is already supported) -- I just added this test since it was in the original PR and figured it couldn't hurt.

log_obj.record("video", 42)
log_obj.dump(step=4)

log_obj.close()

