Revert "Add support for prefix context manager in logger (from #529) (#…
Browse files Browse the repository at this point in the history
…570)"

This reverts commit a738955.
dfilan committed Oct 10, 2022
1 parent a738955 commit 849f7e7
Showing 4 changed files with 85 additions and 308 deletions.
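
For orientation, the sketch below contrasts the two logging-key styles this revert switches between: the add_key_prefix context manager introduced by the reverted commit (removed here), and explicit prefix strings composed by a helper, as _get_logger_key does in the restored code. ToyLogger and get_logger_key are illustrative stand-ins, not imitation's actual logger implementation.

import contextlib
from typing import Dict, Iterator, List, Optional


class ToyLogger:
    """Toy logger: records scalar values under slash-separated string keys."""

    def __init__(self) -> None:
        self.name_to_value: Dict[str, float] = {}
        self._prefixes: List[str] = []

    def record(self, key: str, value: float) -> None:
        self.name_to_value["/".join(self._prefixes + [key])] = value

    # Style removed by this revert: prefixes pushed and popped by a context manager.
    @contextlib.contextmanager
    def add_key_prefix(self, prefix: str) -> Iterator[None]:
        self._prefixes.append(prefix)
        try:
            yield
        finally:
            self._prefixes.pop()


# Style restored by this revert: the prefix is threaded through as an argument
# and composed by a small helper, as _get_logger_key does in the diff below.
def get_logger_key(prefix: Optional[str], key: str) -> str:
    return key if prefix is None else f"{prefix}/{key}"


logger = ToyLogger()
with logger.add_key_prefix("epoch-0"):                       # context-manager style
    logger.record("train/loss", 0.5)
logger.record(get_logger_key("epoch-1", "train/loss"), 0.4)  # explicit-prefix style
print(logger.name_to_value)
# {'epoch-0/train/loss': 0.5, 'epoch-1/train/loss': 0.4}
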
10 changes: 9 additions & 1 deletion setup.py
@@ -21,7 +21,15 @@
"autorom[accept-rom-license]~=0.4.2",
]
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
STABLE_BASELINES3 = "stable-baselines3>=1.6.1"
if IS_NOT_WINDOWS:
# TODO(adam): use this for Windows as well once PyPI is at >=1.6.1
STABLE_BASELINES3 = "stable-baselines3>=1.6.0"
else:
STABLE_BASELINES3 = (
"stable-baselines3@git+"
"https://github.com/DLR-RM/stable-baselines3.git@master"
)

# pinned to 0.21 until https://github.com/DLR-RM/stable-baselines3/pull/780 goes
# upstream.
GYM_VERSION_SPECIFIER = "==0.21.0"
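
An aside on how the conditional requirement above is typically consumed: the sketch below is illustrative only, with placeholder setup() metadata rather than this repository's actual call. It shows that both forms of STABLE_BASELINES3 are valid install_requires entries: a plain version pin, or a PEP 508 "pkg@git+URL" direct reference.

from setuptools import setup

IS_NOT_WINDOWS = True  # stand-in for the platform check defined earlier in setup.py
if IS_NOT_WINDOWS:
    STABLE_BASELINES3 = "stable-baselines3>=1.6.0"
else:
    STABLE_BASELINES3 = (
        "stable-baselines3@git+"
        "https://github.com/DLR-RM/stable-baselines3.git@master"
    )

setup(
    name="example-package",  # placeholder metadata, not imitation's setup() call
    version="0.0.0",
    install_requires=[STABLE_BASELINES3, "gym==0.21.0"],
)
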
113 changes: 56 additions & 57 deletions src/imitation/algorithms/preference_comparisons.py
@@ -7,7 +7,6 @@
import math
import pickle
import random
import re
from collections import defaultdict
from typing import (
Any,
@@ -1153,6 +1152,7 @@ def _train(
self,
dataset: PreferenceDataset,
epoch_multiplier: float = 1.0,
prefix: Optional[str] = None,
) -> None:
"""Trains for `epoch_multiplier * self.epochs` epochs over `dataset`."""
if self.regularizer is not None and self.regularizer.val_split is not None:
@@ -1180,65 +1180,71 @@ def _train(
epochs = round(self.epochs * epoch_multiplier)

assert epochs > 0, "Must train for at least one epoch."
epoch_num = 0
with self.logger.accumulate_means("reward"):
for epoch_num in tqdm(range(epochs), desc="Training reward model"):
with self.logger.add_key_prefix(f"epoch-{epoch_num}"):
train_loss = 0.0
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
with self.logger.add_key_prefix("train"):
loss = self._training_inner_loop(
fragment_pairs,
preferences,
)
train_loss += loss.item()
if self.regularizer:
self.regularizer.regularize_and_backward(loss)
else:
loss.backward()
self.optim.step()

if not self.requires_regularizer_update:
continue
assert val_dataloader is not None
assert self.regularizer is not None

val_loss = 0.0
for fragment_pairs, preferences in val_dataloader:
with self.logger.add_key_prefix("val"):
val_loss += self._training_inner_loop(
fragment_pairs,
preferences,
).item()
self.regularizer.update_params(train_loss, val_loss)
logger_prefix = self._get_logger_key(prefix, f"epoch-{epoch_num}")
train_loss = 0.0
for fragment_pairs, preferences in dataloader:
self.optim.zero_grad()
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{logger_prefix}/train",
)
train_loss += loss.item()
if self.regularizer:
self.regularizer.regularize_and_backward(loss)
else:
loss.backward()
self.optim.step()

if not self.requires_regularizer_update:
continue
assert val_dataloader is not None
assert self.regularizer is not None

val_loss = 0.0
for fragment_pairs, preferences in val_dataloader:
loss = self._training_inner_loop(
fragment_pairs,
preferences,
prefix=f"{logger_prefix}/val",
)
val_loss += loss.item()
self.regularizer.update_params(train_loss, val_loss)

# after training all the epochs,
# record also the final value in a separate key for easy access.
keys = list(self.logger.name_to_value.keys())
outer_prefix = self.logger.get_accumulate_prefixes()
for key in keys:
base_path = f"{outer_prefix}reward/" # existing prefix + accum_means ctx
epoch_path = f"mean/{base_path}epoch-{epoch_num}/" # mean for last epoch
final_path = f"{base_path}final/" # path to record last epoch
pattern = rf"{epoch_path}(.+)"
if regex_match := re.match(pattern, key):
(key_name,) = regex_match.groups()
if key.startswith("mean/reward/" + logger_prefix):
val = self.logger.name_to_value[key]
new_key = f"{final_path}{key_name}"
new_key = key.replace(
"mean/reward/" + logger_prefix,
"reward/" + self._get_logger_key(prefix, "final"),
)
self.logger.record(new_key, val)

def _training_inner_loop(
self,
fragment_pairs: Sequence[TrajectoryPair],
preferences: np.ndarray,
prefix: Optional[str] = None,
) -> th.Tensor:
output = self.loss.forward(fragment_pairs, preferences, self._preference_model)
loss = output.loss
self.logger.record("loss", loss.item())
self.logger.record(self._get_logger_key(prefix, "loss"), loss.item())
for name, value in output.metrics.items():
self.logger.record(name, value.item())
self.logger.record(self._get_logger_key(prefix, name), value.item())
return loss

# TODO(juan) refactor & remove once #529 is merged.
def _get_logger_key(self, mode: Optional[str], key: str) -> str:
if mode is None:
return key
return f"{mode}/{key}"


class EnsembleTrainer(BasicRewardTrainer):
"""Train a reward ensemble."""
@@ -1303,11 +1309,7 @@ def __init__(
if seed:
self.rng = self.rng.manual_seed(seed)

@property
def logger(self):
return super().logger

@logger.setter
@RewardTrainer.logger.setter
def logger(self, custom_logger):
self._logger = custom_logger
for member_trainer in self.member_trainers:
@@ -1324,20 +1326,20 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> None:
for member_idx in range(len(self.member_trainers)):
# sampler gives new indexes on every call
bagging_dataset = data_th.Subset(dataset, list(sampler))
with self.logger.add_accumulate_prefix(f"member-{member_idx}"):
self.member_trainers[member_idx].train(
bagging_dataset,
epoch_multiplier=epoch_multiplier,
)
self.member_trainers[member_idx]._train(
bagging_dataset,
epoch_multiplier,
prefix=f"member-{member_idx}",
)

# average the metrics across the member models
metrics = defaultdict(list)
keys = list(self.logger.name_to_value.keys())
for key in keys:
if re.match(r"member-(\d+)/reward/(.+)", key) and "final" in key:
if key.startswith("reward/member-") and "final" in key:
val = self.logger.name_to_value[key]
key_list = key.split("/")
key_list.pop(0)
key_list.pop(1)
metrics["/".join(key_list)].append(val)

for k, v in metrics.items():
@@ -1597,11 +1599,8 @@ def train(
epoch_multiplier = self.initial_epoch_multiplier

self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multiplier)
base_key = self.logger.get_accumulate_prefixes() + "reward/final/train"
assert f"{base_key}/loss" in self.logger.name_to_value
assert f"{base_key}/accuracy" in self.logger.name_to_value
reward_loss = self.logger.name_to_value[f"{base_key}/loss"]
reward_accuracy = self.logger.name_to_value[f"{base_key}/accuracy"]
reward_loss = self.logger.name_to_value["reward/final/train/loss"]
reward_accuracy = self.logger.name_to_value["reward/final/train/accuracy"]

###################
# Train the agent #
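A closing illustration of the key scheme this revert restores: BasicRewardTrainer records final-epoch metrics under reward/<prefix>/final/..., EnsembleTrainer drops the member index and averages the per-member values (per the "average the metrics" comment above), and the train() method in the last hunk reads the averaged reward/final/train/loss and reward/final/train/accuracy keys directly. The sketch below replays that averaging step with hard-coded example values; the numbers and the standalone dict are illustrative, not library output.

from collections import defaultdict

name_to_value = {  # example per-member "final" keys, shaped like those recorded above
    "reward/member-0/final/train/loss": 0.5,
    "reward/member-1/final/train/loss": 0.25,
    "reward/member-0/final/train/accuracy": 0.75,
    "reward/member-1/final/train/accuracy": 0.875,
}

metrics = defaultdict(list)
for key, val in name_to_value.items():
    if key.startswith("reward/member-") and "final" in key:
        key_list = key.split("/")  # ["reward", "member-0", "final", "train", "loss"]
        key_list.pop(1)            # drop the member index so all members share a key
        metrics["/".join(key_list)].append(val)

averaged = {k: sum(v) / len(v) for k, v in metrics.items()}
print(averaged["reward/final/train/loss"])      # 0.375
print(averaged["reward/final/train/accuracy"])  # 0.8125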