Implemented the ability to train rewards in preference comparison against multiple policies #529

Closed

levmckinney wants to merge 31 commits into master from policy_ensemble

Changes from 9 commits

Commits (31)
a7eda89
implemented MixtureOfTrajectoryGeneratos
levmckinney Aug 9, 2022
1a9ab5d
added add_prefix method to HierarchicalLogger
levmckinney Aug 17, 2022
d442baa
moved logger.accumulate_means into AgentTrainer
levmckinney Aug 17, 2022
003bbcd
added option for multiple agents to train_preference_comparison
levmckinney Aug 17, 2022
f6a3d44
reduced the length of the prefix used in mixture of generators for lo…
levmckinney Aug 18, 2022
33110aa
Merge branch 'master' into policy_ensemble
levmckinney Aug 18, 2022
8c5433d
improved logic for whether to try and checkpoint policy
levmckinney Aug 18, 2022
83f8b8b
fixed final save checkpoint call
levmckinney Aug 18, 2022
803ffde
fixed log key when using prefix on windows
levmckinney Aug 18, 2022
21e11fe
clarified doc string and added runtime error
levmckinney Aug 22, 2022
5c9fc1e
responded to reviewers comments
levmckinney Aug 22, 2022
e478ec3
fixed logic bug
levmckinney Aug 23, 2022
75c0b80
fixed test
levmckinney Aug 23, 2022
73ace03
added pragma: no cover to test case which needs it
levmckinney Aug 23, 2022
b0b3b93
added doctest to logger explaining behavour
levmckinney Aug 23, 2022
dc60be8
Merge branch 'master' into policy_ensemble
levmckinney Aug 24, 2022
2958d08
added option to split training steps among the agents and made it def…
levmckinney Aug 24, 2022
303b032
Add type annotations to hierarchical logger
Rocamonde Sep 3, 2022
230774c
Move "is single agent" to explicit bool definition
Rocamonde Sep 3, 2022
d44cc85
Raise error when too few steps to partition.
Rocamonde Sep 3, 2022
07fa830
Formatter
Rocamonde Sep 3, 2022
dd72656
Fix "else" that was accidentally removed.
Rocamonde Sep 3, 2022
709cb42
Roll back change
Rocamonde Sep 3, 2022
06084ad
Added description to tests
levmckinney Sep 5, 2022
a2b790f
Unnested with statements
levmckinney Sep 5, 2022
510abd8
removed duplicate sentence in documentation
levmckinney Sep 5, 2022
232b56e
Apply suggestions from code review
levmckinney Sep 5, 2022
09670d3
Merge branch 'master' into policy_ensemble
levmckinney Sep 6, 2022
a5ab914
Merge branch 'master' into policy_ensemble
levmckinney Sep 13, 2022
408e248
fixed comment in doctest
levmckinney Sep 13, 2022
4df8551
Update src/imitation/algorithms/preference_comparisons.py
Rocamonde Sep 14, 2022
107 changes: 93 additions & 14 deletions src/imitation/algorithms/preference_comparisons.py
@@ -3,6 +3,8 @@
Trains a reward model and optionally a policy based on preferences
between trajectory fragments.
"""
from __future__ import generators

import abc
import math
import pickle
@@ -210,18 +212,19 @@ def train(self, steps: int, **kwargs) -> None:
RuntimeError: Transitions left in `self.buffering_wrapper`; call
`self.sample` first to clear them.
"""
n_transitions = self.buffering_wrapper.n_transitions
if n_transitions:
raise RuntimeError(
f"There are {n_transitions} transitions left in the buffer. "
"Call AgentTrainer.sample() first to clear them.",
with self.logger.accumulate_means("agent"):
n_transitions = self.buffering_wrapper.n_transitions
if n_transitions:
raise RuntimeError(
f"There are {n_transitions} transitions left in the buffer. "
"Call AgentTrainer.sample() first to clear them.",
)
self.algorithm.learn(
total_timesteps=steps,
reset_num_timesteps=False,
callback=self.log_callback,
**kwargs,
)
self.algorithm.learn(
total_timesteps=steps,
reset_num_timesteps=False,
callback=self.log_callback,
**kwargs,
)

def sample(self, steps: int) -> Sequence[types.TrajectoryWithRew]:
agent_trajs, _ = self.buffering_wrapper.pop_finished_trajectories()
@@ -299,6 +302,83 @@ def logger(self, value: imit_logger.HierarchicalLogger):
self.algorithm.set_logger(self.logger)


class MixtureOfTrajectoryGenerators(TrajectoryGenerator):
"""A collection of trajectory generators merged together."""

members: Sequence[TrajectoryGenerator]

def __init__(
self,
members: Sequence[TrajectoryGenerator],
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
):
"""Create a mixture of trajectory generators.

Args:
members: Individual trajectory generators that will make up the ensemble.
custom_logger: Custom logger passed to super class.

Raises:
ValueError: if members is empty.
"""
if len(members) == 0:
raise ValueError(
"MixtureOfTrajectoryGenerators requires at least one member!",
)
self.members = tuple(members)
super().__init__(custom_logger=custom_logger)

def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
"""Sample a batch of trajectories splitting evenly amongst the mixture members.

Args:
steps: All trajectories taken together should
have at least this many steps.

Returns:
A list of sampled trajectories with rewards (which should
be the environment rewards, not ones from a reward model).
"""
n = len(self.members)
# Approximately evenly partition work.
d = steps // n
r = steps % n
i = np.random.randint(n)
partition = [d] * n
partition[i] += r
trajectories = []

for s, generator in zip(partition, self.members):
trajectories.extend(generator.sample(s))

return trajectories

def train(self, steps: int, **kwargs):
"""Train an agent if the trajectory generator uses one.

By default, this method does nothing and doesn't need
to be overridden in subclasses that don't require training.

Args:
steps: number of environment steps to train for.
**kwargs: additional keyword arguments to pass on to
the training procedure.
"""
for i, generator in enumerate(self.members):
with self.logger.add_prefix(f"gen_{i}"):
generator.train(steps, **kwargs)

@property
def logger(self) -> imit_logger.HierarchicalLogger:
return self._logger

@logger.setter
def logger(self, value: imit_logger.HierarchicalLogger):
self._logger = value
for generator in self.members:
generator.logger = value


def _get_trajectories(
trajectories: Sequence[TrajectoryWithRew],
steps: int,
@@ -1283,9 +1363,8 @@ def train(
# at the end of training (where the reward model is presumably best)
if i == self.num_iterations - 1:
num_steps += extra_timesteps
with self.logger.accumulate_means("agent"):
self.logger.log(f"Training agent for {num_steps} timesteps")
self.trajectory_generator.train(steps=num_steps)
self.logger.log(f"Training agent for {num_steps} timesteps")
self.trajectory_generator.train(steps=num_steps)

self.logger.dump(self._iteration)

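The `sample` method above splits the requested step budget as evenly as possible across the members and hands the remainder to one member chosen at random. A minimal standalone sketch of that partitioning logic (the `split_steps` helper is illustrative, not part of the PR):

import numpy as np


def split_steps(steps: int, n_members: int, rng: np.random.Generator) -> list:
    """Partition `steps` across members; one random member absorbs the remainder."""
    base, remainder = divmod(steps, n_members)
    partition = [base] * n_members
    partition[rng.integers(n_members)] += remainder
    return partition


# Splitting 11 steps across 2 members yields [6, 5] or [5, 6]; either way the
# per-member requests sum to 11, which is what the new unit test asserts.
print(split_steps(11, 2, np.random.default_rng(0)))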
@@ -43,6 +43,7 @@ def train_defaults():
save_preferences = False # save preference dataset at the end?
agent_path = None # path to a (partially) trained agent to load at the beginning
# type of PreferenceGatherer to use
num_agents = 1 # The number of agents to train the reward against.
gatherer_cls = preference_comparisons.SyntheticGatherer
# arguments passed on to the PreferenceGatherer specified by gatherer_cls
gatherer_kwargs = {}
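The new `num_agents` value can be overridden like any other entry in this Sacred config. A hedged sketch of a programmatic override follows; the experiment object name, its module path, and the `pendulum` named config are assumptions about the repository layout rather than anything shown in this diff.

# Hypothetical: run the preference comparison script against two agents.
# Adjust the import path and named configs to match your checkout.
from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)

run = train_preference_comparisons_ex.run(
    named_configs=["pendulum"],        # assumed environment named config
    config_updates={"num_agents": 2},  # train the reward against two policies
)
print(run.result)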
33 changes: 28 additions & 5 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -69,6 +69,7 @@ def train_preference_comparisons(
trajectory_generator_kwargs: Mapping[str, Any],
save_preferences: bool,
agent_path: Optional[str],
num_agents: int,
cross_entropy_loss_kwargs: Mapping[str, Any],
reward_trainer_kwargs: Mapping[str, Any],
gatherer_cls: Type[preference_comparisons.PreferenceGatherer],
@@ -110,6 +111,7 @@ def train_preference_comparisons(
save_preferences: if True, store the final dataset of preferences to disk.
agent_path: if given, initialize the agent using this stored policy
rather than randomly.
num_agents: number of agents to train the reward model against.
cross_entropy_loss_kwargs: passed to CrossEntropyRewardLoss
reward_trainer_kwargs: passed to BasicRewardTrainer or EnsembleRewardTrainer
gatherer_cls: type of PreferenceGatherer to use (defaults to SyntheticGatherer)
@@ -155,22 +157,43 @@ def train_preference_comparisons(
relabel_reward_fn=relabel_reward_fn,
)

if trajectory_path is None:
if num_agents < 1:
raise ValueError("num_agents must be at least 1!")

def make_agent_trainer(seed: Optional[int] = None):
if agent_path is None:
agent = rl_common.make_rl_algo(venv)
else:
agent = rl_common.load_rl_algo_from_path(agent_path=agent_path, venv=venv)

# Setting the logger here is not really necessary (PreferenceComparisons
# takes care of that automatically) but it avoids creating unnecessary loggers
trajectory_generator = preference_comparisons.AgentTrainer(
return preference_comparisons.AgentTrainer(
algorithm=agent,
reward_fn=reward_net,
venv=venv,
exploration_frac=exploration_frac,
seed=_seed,
seed=_seed if seed is None else seed,
custom_logger=custom_logger,
**trajectory_generator_kwargs,
)

if trajectory_path is None and num_agents == 1:
trajectory_generator = make_agent_trainer()
# Stable Baselines will automatically occupy GPU 0 if it is available. Let's use
# the same device as the SB3 agent for the reward model.
reward_net = reward_net.to(trajectory_generator.algorithm.device)
allow_save_policy = True
elif trajectory_path is None and num_agents > 1:
members = [make_agent_trainer(_seed * i) for i in range(num_agents)]
trajectory_generator = preference_comparisons.MixtureOfTrajectoryGenerators(
members=members,
custom_logger=custom_logger,
)
reward_net = reward_net.to(members[0].algorithm.device)
allow_save_policy = False
else:
allow_save_policy = False
if exploration_frac > 0:
raise ValueError(
"exploration_frac can't be set when a trajectory dataset is used",
@@ -225,7 +248,7 @@ def save_callback(iteration_num):
save_checkpoint(
trainer=main_trainer,
save_path=os.path.join(log_dir, "checkpoints", f"{iteration_num:04d}"),
allow_save_policy=bool(trajectory_path is None),
allow_save_policy=allow_save_policy,
)

results = main_trainer.train(
@@ -242,7 +265,7 @@ def save_callback(iteration_num):
save_checkpoint(
trainer=main_trainer,
save_path=os.path.join(log_dir, "checkpoints", "final"),
allow_save_policy=bool(trajectory_path is None),
allow_save_policy=allow_save_policy,
)

# Storing and evaluating policy only useful if we actually generate trajectory data
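Stripped of the Sacred plumbing, the script change reduces to a three-way choice: a single `AgentTrainer` when `num_agents == 1`, a `MixtureOfTrajectoryGenerators` over independently seeded trainers when `num_agents > 1`, and the pre-existing fixed-dataset path when `trajectory_path` is set, with policy checkpointing only allowed in the single-agent case. The sketch below mirrors that logic with stand-in factories; the function and its callbacks are illustrative, not the script's actual API.

from typing import Callable, Optional, Sequence, Tuple


def choose_trajectory_generator(
    num_agents: int,
    trajectory_path: Optional[str],
    seed: int,
    make_agent_trainer: Callable[[Optional[int]], object],
    make_mixture: Callable[[Sequence[object]], object],
    load_dataset: Callable[[str], object],
) -> Tuple[object, bool]:
    """Return (trajectory_generator, allow_save_policy), mirroring the script's branch."""
    if num_agents < 1:
        raise ValueError("num_agents must be at least 1!")
    if trajectory_path is not None:
        # Preferences over a fixed trajectory dataset: no policy checkpoint to save.
        return load_dataset(trajectory_path), False
    if num_agents == 1:
        # Single agent: policy checkpoints can be saved as before.
        return make_agent_trainer(None), True
    # Multiple agents: each member gets its own seed; no single policy to checkpoint.
    members = [make_agent_trainer(seed * i) for i in range(num_agents)]
    return make_mixture(members), False


# Smoke test with stand-in factories:
gen, allow_save = choose_trajectory_generator(
    num_agents=3,
    trajectory_path=None,
    seed=7,
    make_agent_trainer=lambda s: f"trainer(seed={s})",
    make_mixture=list,
    load_dataset=lambda p: f"dataset({p})",
)
print(gen, allow_save)  # ['trainer(seed=0)', 'trainer(seed=7)', 'trainer(seed=14)'] False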
45 changes: 34 additions & 11 deletions src/imitation/util/logger.py
@@ -62,7 +62,9 @@ def __init__(
self.default_logger = default_logger
self.current_logger = None
self._cached_loggers = {}
self._prefixes = []
self._subdir = None
self._name = None
self.format_strs = format_strs
super().__init__(folder=self.default_logger.dir, output_formats=[])

@@ -72,27 +74,44 @@ def _update_name_to_maps(self) -> None:
self.name_to_excluded = self._logger.name_to_excluded

@contextlib.contextmanager
def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]:
def add_prefix(self, prefix: str) -> Generator[None, None, None]:
Review thread on `add_prefix`:

Member: If the prefix is removed after leaving the context, and only one prefix is supported as input, why are you adding support for a list of prefixes? How and why would I use multiple prefixes? (Is the expectation that I should enter nested prefix contexts? That might be quite hard to read and understand in practice, e.g. if this happens in different files or function calls.) I also think this prefix/name idea should be explained with an example in, e.g., the class-level docstring.

Member: If I'm understanding correctly how this is supposed to be used, entering this should be disallowed if one is using an accumulate_means context, otherwise that would mess with the path where the rest of the logs are being recorded. Do you agree, and if so, do you think you can add a way to throw an error?

levmckinney (Collaborator, Author): Agreed, I've added a runtime error for this case.

levmckinney (Collaborator, Author): I've added a doctest documenting how to use it.

"""Add a prefix to the subdirectory used to accumulate means.

Args:
prefix: The prefix to add to the named sub.

Yields:
None when the context manager is entered
"""
try:
self._prefixes.append(prefix)
yield
finally:
self._prefixes.pop()

@contextlib.contextmanager
def accumulate_means(self, name: types.AnyPath) -> Generator[None, None, None]:
"""Temporarily modifies this HierarchicalLogger to accumulate means values.

During this context, `self.record(key, value)` writes the "raw" values in
"{self.default_logger.log_dir}/{subdir}" under the key "raw/{subdir}/{key}".
At the same time, any call to `self.record` will also accumulate mean values
on the default logger by calling
`self.default_logger.record_mean(f"mean/{subdir}/{key}", value)`.
"{self.default_logger.log_dir}/{prefix}/{name}" under the key
"raw/{prefix}/{name}/{key}". At the same time, any call to `self.record` will
also accumulate mean values on the default logger by calling
`self.default_logger.record_mean(f"mean/{prefix}/{name}/{key}", value)`.

During the context, `self.record(key, value)` will write the "raw" values in
`"{self.default_logger.log_dir}/subdir"` under the key "raw/{subdir}/key".
`"{self.default_logger.log_dir}/name"` under the key
"raw/{prefix}/{name}/key".

After the context exits, calling `self.dump()` will write the means
of all the "raw" values accumulated during this context to
`self.default_logger` under keys with the prefix `mean/{subdir}/`
`self.default_logger` under keys with the prefix `mean/{prefix}/{name}/`

Note that the behavior of other logging methods, `log` and `record_mean`
are unmodified and will go straight to the default logger.

Args:
subdir: A string key which determines the `folder` where raw data is
name: A string key which determines the `folder` where raw data is
written and temporary logging prefixes for raw and mean data. Entering
an `accumulate_means` context in the future with the same `subdir`
will safely append to logs written in this folder rather than
@@ -108,10 +127,12 @@
if self.current_logger is not None:
raise RuntimeError("Nested `accumulate_means` context")

name = types.path_to_str(name)
subdir = os.path.join(*self._prefixes, name)

if subdir in self._cached_loggers:
logger = self._cached_loggers[subdir]
else:
subdir = types.path_to_str(subdir)
folder = os.path.join(self.default_logger.dir, "raw", subdir)
os.makedirs(folder, exist_ok=True)
output_formats = _build_output_formats(folder, self.format_strs)
@@ -121,20 +142,22 @@
try:
self.current_logger = logger
self._subdir = subdir
self._name = name
self._update_name_to_maps()
yield
finally:
self.current_logger = None
self._subdir = None
self._name = None
self._update_name_to_maps()

def record(self, key, val, exclude=None):
if self.current_logger is not None: # In accumulate_means context.
assert self._subdir is not None
raw_key = "/".join(["raw", self._subdir, key])
raw_key = "/".join(["raw", *self._prefixes, self._name, key])
self.current_logger.record(raw_key, val, exclude)

mean_key = "/".join(["mean", self._subdir, key])
mean_key = "/".join(["mean", *self._prefixes, self._name, key])
self.default_logger.record_mean(mean_key, val, exclude)
else: # Not in accumulate_means context.
self.default_logger.record(key, val, exclude)
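Taken together, `add_prefix` and `accumulate_means` determine the key layout: raw values land under `raw/{prefix}/{name}/{key}` and dumped means under `mean/{prefix}/{name}/{key}`. A hedged usage sketch follows; the `imit_logger.configure(folder, format_strs)` helper is assumed to behave as in `imitation.util.logger`, and per the review thread above, entering `add_prefix` inside an `accumulate_means` context is expected to raise an error in later commits.

import tempfile

from imitation.util import logger as imit_logger

# Assumed helper from imitation.util.logger; it returns a HierarchicalLogger.
hier_logger = imit_logger.configure(tempfile.mkdtemp(), ["stdout"])

# Prefix first, then accumulate: raw values are recorded under
# "raw/gen_0/agent/loss" inside the "raw/gen_0/agent" subdirectory.
with hier_logger.add_prefix("gen_0"), hier_logger.accumulate_means("agent"):
    hier_logger.record("loss", 0.5)
    hier_logger.record("loss", 1.5)

# After leaving the context, dump() writes the mean (1.0 here) to the default
# logger under "mean/gen_0/agent/loss".
hier_logger.dump(step=0)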
34 changes: 34 additions & 0 deletions tests/algorithms/test_preference_comparisons.py
@@ -2,6 +2,7 @@

import re
from typing import Sequence
from unittest import mock

import numpy as np
import pytest
@@ -55,6 +56,17 @@ def fragmenter():
return preference_comparisons.RandomFragmenter(seed=0, warning_threshold=0)


@pytest.fixture
def trajectory_with_reward() -> TrajectoryWithRew:
return TrajectoryWithRew(
obs=np.zeros((33, 10), dtype=float),
acts=np.zeros(32, dtype=int),
infos=None,
terminal=False,
rews=np.zeros(32, dtype=float),
)


@pytest.fixture
def agent_trainer(agent, reward_net, venv):
return preference_comparisons.AgentTrainer(agent, reward_net, venv)
@@ -171,6 +183,28 @@ def test_transitions_left_in_buffer(agent_trainer):
agent_trainer.train(steps=1)


def test_mixture_of_trajectory_generators_train_and_sample(trajectory_with_reward):
gen_1 = mock.Mock(spec=preference_comparisons.TrajectoryGenerator)
gen_2 = mock.Mock(spec=preference_comparisons.TrajectoryGenerator)
gen_1.sample.return_value = 6 * [trajectory_with_reward]
gen_2.sample.return_value = 6 * [trajectory_with_reward]
mixture = preference_comparisons.MixtureOfTrajectoryGenerators(
members=(gen_1, gen_2),
)
mixture.train(steps=10, foo=4)
assert gen_1.train.called_once_with(steps=10, foo=4)
assert gen_2.train.called_once_with(steps=10, foo=4)
mixture.sample(11)
assert gen_1.sample.call_args.args[0] + gen_2.sample.call_args.args[0] == 11


def test_mixture_of_trajectory_generators_raises_value_error_when_members_is_empty():
with pytest.raises(ValueError):
preference_comparisons.MixtureOfTrajectoryGenerators(
members=[],
)


@pytest.mark.parametrize(
"schedule",
["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)],
Expand Down
Loading