Add support for prefix context manager in logger (from #529) (#570)
* Add support for prefix context manager in logger (from #529)

* Added back accidentally removed code

* Replaced preference comparisons prefix with ctx manager

* Fixed errors

* Docstring fixes

* Address PR comments

* Point SB3 to master to include bug fix

* Format / fix tests for context manager

* Switch to sb3 1.6.1

* Formatting

* Remove comment
Rocamonde authored Sep 30, 2022
1 parent f00913d commit a738955
Showing 4 changed files with 308 additions and 85 deletions.
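
The core of the change in src/imitation/algorithms/preference_comparisons.py: instead of threading an optional `prefix` string through `_train` and `_training_inner_loop` and joining keys by hand in the `_get_logger_key` helper, the reward trainers now scope their logging keys with the logger's context managers. A rough usage sketch, not taken from this commit; it assumes `imitation.util.logger.configure` builds the `HierarchicalLogger` that exposes the `accumulate_means`, `add_key_prefix`, and `add_accumulate_prefix` context managers used in the diff below:

# Sketch only; key names mirror those used by BasicRewardTrainer below.
from imitation.util import logger as imit_logger

log = imit_logger.configure()  # assumed to return imitation's HierarchicalLogger

# Old pattern (removed here): pass prefix="..." and build f"{prefix}/loss" by hand.
# New pattern: push prefixes with context managers instead of string arguments.
with log.accumulate_means("reward"):      # values recorded inside are averaged under "reward/..."
    with log.add_key_prefix("epoch-0"):
        with log.add_key_prefix("train"):
            log.record("loss", 0.123)     # dummy value
# The running mean is tracked under "mean/reward/epoch-0/train/loss".
log.dump()
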
10 changes: 1 addition & 9 deletions setup.py
@@ -21,15 +21,7 @@
"autorom[accept-rom-license]~=0.4.2",
]
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
if IS_NOT_WINDOWS:
# TODO(adam): use this for Windows as well once PyPI is at >=1.6.1
STABLE_BASELINES3 = "stable-baselines3>=1.6.0"
else:
STABLE_BASELINES3 = (
"stable-baselines3@git+"
"https://github.com/DLR-RM/stable-baselines3.git@master"
)

STABLE_BASELINES3 = "stable-baselines3>=1.6.1"
# pinned to 0.21 until https://github.com/DLR-RM/stable-baselines3/pull/780 goes
# upstream.
GYM_VERSION_SPECIFIER = "==0.21.0"
113 changes: 57 additions & 56 deletions src/imitation/algorithms/preference_comparisons.py
@@ -7,6 +7,7 @@
import math
import pickle
import random
import re
from collections import defaultdict
from typing import (
Any,
@@ -1152,7 +1153,6 @@ def _train(
self,
dataset: PreferenceDataset,
epoch_multiplier: float = 1.0,
prefix: Optional[str] = None,
) -> None:
"""Trains for `epoch_multiplier * self.epochs` epochs over `dataset`."""
if self.regularizer is not None and self.regularizer.val_split is not None:
@@ -1180,71 +1180,65 @@
epochs = round(self.epochs * epoch_multiplier)

assert epochs > 0, "Must train for at least one epoch."
epoch_num = 0
with self.logger.accumulate_means("reward"):
for epoch_num in tqdm(range(epochs), desc="Training reward model"):
logger_prefix = self._get_logger_key(prefix, f"epoch-{epoch_num}")
train_loss = 0.0
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{logger_prefix}/train",
)
train_loss += loss.item()
if self.regularizer:
self.regularizer.regularize_and_backward(loss)
else:
loss.backward()
self.optim.step()

if not self.requires_regularizer_update:
continue
assert val_dataloader is not None
assert self.regularizer is not None

val_loss = 0.0
for fragment_pairs, preferences in val_dataloader:
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{logger_prefix}/val",
)
val_loss += loss.item()
self.regularizer.update_params(train_loss, val_loss)
with self.logger.add_key_prefix(f"epoch-{epoch_num}"):
train_loss = 0.0
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
with self.logger.add_key_prefix("train"):
loss = self._training_inner_loop(
fragment_pairs,
preferences,
)
train_loss += loss.item()
if self.regularizer:
self.regularizer.regularize_and_backward(loss)
else:
loss.backward()
self.optim.step()

if not self.requires_regularizer_update:
continue
assert val_dataloader is not None
assert self.regularizer is not None

val_loss = 0.0
for fragment_pairs, preferences in val_dataloader:
with self.logger.add_key_prefix("val"):
val_loss += self._training_inner_loop(
fragment_pairs,
preferences,
).item()
self.regularizer.update_params(train_loss, val_loss)

# after training all the epochs,
# record also the final value in a separate key for easy access.
keys = list(self.logger.name_to_value.keys())
outer_prefix = self.logger.get_accumulate_prefixes()
for key in keys:
if key.startswith("mean/reward/" + logger_prefix):
base_path = f"{outer_prefix}reward/" # existing prefix + accum_means ctx
epoch_path = f"mean/{base_path}epoch-{epoch_num}/" # mean for last epoch
final_path = f"{base_path}final/" # path to record last epoch
pattern = rf"{epoch_path}(.+)"
if regex_match := re.match(pattern, key):
(key_name,) = regex_match.groups()
val = self.logger.name_to_value[key]
new_key = key.replace(
"mean/reward/" + logger_prefix,
"reward/" + self._get_logger_key(prefix, "final"),
)
new_key = f"{final_path}{key_name}"
self.logger.record(new_key, val)
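
Not part of the diff: a minimal sketch of the key rewriting above, assuming an empty outer accumulate prefix and a final epoch number of 2. The last epoch's accumulated means are copied to a stable "final" path so later code can look them up without knowing how many epochs ran:

import re

key = "mean/reward/epoch-2/train/loss"  # as produced while accumulate_means("reward") was active
base_path = "reward/"                   # outer_prefix == "" in this sketch
epoch_path = f"mean/{base_path}epoch-2/"
final_path = f"{base_path}final/"

if regex_match := re.match(rf"{epoch_path}(.+)", key):
    (key_name,) = regex_match.groups()
    print(f"{final_path}{key_name}")    # -> reward/final/train/loss
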

def _training_inner_loop(
self,
fragment_pairs: Sequence[TrajectoryPair],
preferences: np.ndarray,
prefix: Optional[str] = None,
) -> th.Tensor:
output = self.loss.forward(fragment_pairs, preferences, self._preference_model)
loss = output.loss
self.logger.record(self._get_logger_key(prefix, "loss"), loss.item())
self.logger.record("loss", loss.item())
for name, value in output.metrics.items():
self.logger.record(self._get_logger_key(prefix, name), value.item())
self.logger.record(name, value.item())
return loss

# TODO(juan) refactor & remove once #529 is merged.
def _get_logger_key(self, mode: Optional[str], key: str) -> str:
if mode is None:
return key
return f"{mode}/{key}"


class EnsembleTrainer(BasicRewardTrainer):
"""Train a reward ensemble."""
@@ -1309,7 +1303,11 @@ def __init__(
if seed:
self.rng = self.rng.manual_seed(seed)

@RewardTrainer.logger.setter
@property
def logger(self):
return super().logger

@logger.setter
def logger(self, custom_logger):
self._logger = custom_logger
for member_trainer in self.member_trainers:
@@ -1326,20 +1324,20 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> N
for member_idx in range(len(self.member_trainers)):
# sampler gives new indexes on every call
bagging_dataset = data_th.Subset(dataset, list(sampler))
self.member_trainers[member_idx]._train(
bagging_dataset,
epoch_multiplier,
prefix=f"member-{member_idx}",
)
with self.logger.add_accumulate_prefix(f"member-{member_idx}"):
self.member_trainers[member_idx].train(
bagging_dataset,
epoch_multiplier=epoch_multiplier,
)

# average the metrics across the member models
metrics = defaultdict(list)
keys = list(self.logger.name_to_value.keys())
for key in keys:
if key.startswith("reward/member-") and "final" in key:
if re.match(r"member-(\d+)/reward/(.+)", key) and "final" in key:
val = self.logger.name_to_value[key]
key_list = key.split("/")
key_list.pop(1)
key_list.pop(0)
metrics["/".join(key_list)].append(val)

for k, v in metrics.items():
@@ -1599,8 +1597,11 @@ def train(
epoch_multiplier = self.initial_epoch_multiplier

self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multiplier)
reward_loss = self.logger.name_to_value["reward/final/train/loss"]
reward_accuracy = self.logger.name_to_value["reward/final/train/accuracy"]
base_key = self.logger.get_accumulate_prefixes() + "reward/final/train"
assert f"{base_key}/loss" in self.logger.name_to_value
assert f"{base_key}/accuracy" in self.logger.name_to_value
reward_loss = self.logger.name_to_value[f"{base_key}/loss"]
reward_accuracy = self.logger.name_to_value[f"{base_key}/accuracy"]
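
Not part of the diff: a tiny sketch of what this lookup resolves to. `get_accumulate_prefixes()` is assumed to return the concatenation of any active `add_accumulate_prefix` scopes (the empty string in the common case), so `base_key` normally reduces to the previously hard-coded "reward/final/train" path:

# Illustration only; name_to_value stands in for the logger's accumulated values,
# and the numbers are dummies.
name_to_value = {
    "reward/final/train/loss": 0.42,
    "reward/final/train/accuracy": 0.91,
}
outer = ""  # get_accumulate_prefixes() with no add_accumulate_prefix scope active
base_key = outer + "reward/final/train"
reward_loss = name_to_value[f"{base_key}/loss"]
reward_accuracy = name_to_value[f"{base_key}/accuracy"]
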

###################
# Train the agent #
