From 849f7e7a9630c07fe7e99c2e308ab43eb881c7ea Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Mon, 10 Oct 2022 15:13:05 +0200 Subject: [PATCH] Revert "Add support for prefix context manager in logger (from #529) (#570)" This reverts commit a7389559fc48370d6361a6b42bb4b050aff2f11a. --- setup.py | 10 +- .../algorithms/preference_comparisons.py | 113 ++++++----- src/imitation/util/logger.py | 176 ++---------------- tests/util/test_logger.py | 94 +--------- 4 files changed, 85 insertions(+), 308 deletions(-) diff --git a/setup.py b/setup.py index f15479dad..01c41e084 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,15 @@ "autorom[accept-rom-license]~=0.4.2", ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] -STABLE_BASELINES3 = "stable-baselines3>=1.6.1" +if IS_NOT_WINDOWS: + # TODO(adam): use this for Windows as well once PyPI is at >=1.6.1 + STABLE_BASELINES3 = "stable-baselines3>=1.6.0" +else: + STABLE_BASELINES3 = ( + "stable-baselines3@git+" + "https://github.com/DLR-RM/stable-baselines3.git@master" + ) + # pinned to 0.21 until https://github.com/DLR-RM/stable-baselines3/pull/780 goes # upstream. GYM_VERSION_SPECIFIER = "==0.21.0" diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7d37b337b..092f3eeda 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -7,7 +7,6 @@ import math import pickle import random -import re from collections import defaultdict from typing import ( Any, @@ -1153,6 +1152,7 @@ def _train( self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0, + prefix: Optional[str] = None, ) -> None: """Trains for `epoch_multiplier * self.epochs` epochs over `dataset`.""" if self.regularizer is not None and self.regularizer.val_split is not None: @@ -1180,65 +1180,71 @@ def _train( epochs = round(self.epochs * epoch_multiplier) assert epochs > 0, "Must train for at least one epoch." + epoch_num = 0 with self.logger.accumulate_means("reward"): for epoch_num in tqdm(range(epochs), desc="Training reward model"): - with self.logger.add_key_prefix(f"epoch-{epoch_num}"): - train_loss = 0.0 - for fragment_pairs, preferences in dataloader: - self.optim.zero_grad() - with self.logger.add_key_prefix("train"): - loss = self._training_inner_loop( - fragment_pairs, - preferences, - ) - train_loss += loss.item() - if self.regularizer: - self.regularizer.regularize_and_backward(loss) - else: - loss.backward() - self.optim.step() - - if not self.requires_regularizer_update: - continue - assert val_dataloader is not None - assert self.regularizer is not None - - val_loss = 0.0 - for fragment_pairs, preferences in val_dataloader: - with self.logger.add_key_prefix("val"): - val_loss += self._training_inner_loop( - fragment_pairs, - preferences, - ).item() - self.regularizer.update_params(train_loss, val_loss) + logger_prefix = self._get_logger_key(prefix, f"epoch-{epoch_num}") + train_loss = 0.0 + for fragment_pairs, preferences in dataloader: + self.optim.zero_grad() + loss = self._training_inner_loop( + fragment_pairs, + preferences, + prefix=f"{logger_prefix}/train", + ) + train_loss += loss.item() + if self.regularizer: + self.regularizer.regularize_and_backward(loss) + else: + loss.backward() + self.optim.step() + + if not self.requires_regularizer_update: + continue + assert val_dataloader is not None + assert self.regularizer is not None + + val_loss = 0.0 + for fragment_pairs, preferences in val_dataloader: + loss = self._training_inner_loop( + fragment_pairs, + preferences, + prefix=f"{logger_prefix}/val", + ) + val_loss += loss.item() + self.regularizer.update_params(train_loss, val_loss) # after training all the epochs, # record also the final value in a separate key for easy access. keys = list(self.logger.name_to_value.keys()) - outer_prefix = self.logger.get_accumulate_prefixes() for key in keys: - base_path = f"{outer_prefix}reward/" # existing prefix + accum_means ctx - epoch_path = f"mean/{base_path}epoch-{epoch_num}/" # mean for last epoch - final_path = f"{base_path}final/" # path to record last epoch - pattern = rf"{epoch_path}(.+)" - if regex_match := re.match(pattern, key): - (key_name,) = regex_match.groups() + if key.startswith("mean/reward/" + logger_prefix): val = self.logger.name_to_value[key] - new_key = f"{final_path}{key_name}" + new_key = key.replace( + "mean/reward/" + logger_prefix, + "reward/" + self._get_logger_key(prefix, "final"), + ) self.logger.record(new_key, val) def _training_inner_loop( self, fragment_pairs: Sequence[TrajectoryPair], preferences: np.ndarray, + prefix: Optional[str] = None, ) -> th.Tensor: output = self.loss.forward(fragment_pairs, preferences, self._preference_model) loss = output.loss - self.logger.record("loss", loss.item()) + self.logger.record(self._get_logger_key(prefix, "loss"), loss.item()) for name, value in output.metrics.items(): - self.logger.record(name, value.item()) + self.logger.record(self._get_logger_key(prefix, name), value.item()) return loss + # TODO(juan) refactor & remove once #529 is merged. + def _get_logger_key(self, mode: Optional[str], key: str) -> str: + if mode is None: + return key + return f"{mode}/{key}" + class EnsembleTrainer(BasicRewardTrainer): """Train a reward ensemble.""" @@ -1303,11 +1309,7 @@ def __init__( if seed: self.rng = self.rng.manual_seed(seed) - @property - def logger(self): - return super().logger - - @logger.setter + @RewardTrainer.logger.setter def logger(self, custom_logger): self._logger = custom_logger for member_trainer in self.member_trainers: @@ -1324,20 +1326,20 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> N for member_idx in range(len(self.member_trainers)): # sampler gives new indexes on every call bagging_dataset = data_th.Subset(dataset, list(sampler)) - with self.logger.add_accumulate_prefix(f"member-{member_idx}"): - self.member_trainers[member_idx].train( - bagging_dataset, - epoch_multiplier=epoch_multiplier, - ) + self.member_trainers[member_idx]._train( + bagging_dataset, + epoch_multiplier, + prefix=f"member-{member_idx}", + ) # average the metrics across the member models metrics = defaultdict(list) keys = list(self.logger.name_to_value.keys()) for key in keys: - if re.match(r"member-(\d+)/reward/(.+)", key) and "final" in key: + if key.startswith("reward/member-") and "final" in key: val = self.logger.name_to_value[key] key_list = key.split("/") - key_list.pop(0) + key_list.pop(1) metrics["/".join(key_list)].append(val) for k, v in metrics.items(): @@ -1597,11 +1599,8 @@ def train( epoch_multiplier = self.initial_epoch_multiplier self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multiplier) - base_key = self.logger.get_accumulate_prefixes() + "reward/final/train" - assert f"{base_key}/loss" in self.logger.name_to_value - assert f"{base_key}/accuracy" in self.logger.name_to_value - reward_loss = self.logger.name_to_value[f"{base_key}/loss"] - reward_accuracy = self.logger.name_to_value[f"{base_key}/accuracy"] + reward_loss = self.logger.name_to_value["reward/final/train/loss"] + reward_accuracy = self.logger.name_to_value["reward/final/train/accuracy"] ################### # Train the agent # diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index ea59fed6e..b60b4d74e 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -5,7 +5,7 @@ import os import sys import tempfile -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Optional, Sequence, Tuple, Union import stable_baselines3.common.logger as sb_logger @@ -16,7 +16,7 @@ def make_output_format( _format: str, log_dir: str, log_suffix: str = "", - max_length: int = 50, + max_length: int = 40, ) -> sb_logger.KVWriter: """Returns a logger for the requested format. @@ -72,58 +72,8 @@ class HierarchicalLogger(sb_logger.Logger): `self.accumulate_means` creates a context manager. While in this context, values are loggged to a sub-logger, with only mean values recorded in the top-level (root) logger. - - >>> import tempfile - >>> with tempfile.TemporaryDirectory() as dir: - ... logger: HierarchicalLogger = configure(dir, ('log',)) - ... # record the key value pair (loss, 1.0) to path `dir` - ... # at step 1. - ... logger.record("loss", 1.0) - ... logger.dump(step=1) - ... with logger.accumulate_means("dataset"): - ... # record the key value pair `("raw/dataset/entropy", 5.0)` to path - ... # `dir/raw/dataset` at step 100 - ... logger.record("entropy", 5.0) - ... logger.dump(step=100) - ... # record the key value pair `("raw/dataset/entropy", 6.0)` to path - ... # `dir/raw/dataset` at step 200 - ... logger.record("entropy", 6.0) - ... logger.dump(step=200) - ... # record the key value pair `("mean/dataset/entropy", 5.5)` to path - ... # `dir` at step 1. - ... logger.dump(step=1) - ... with logger.add_accumulate_prefix("foo"), logger.accumulate_means("bar"): - ... # record the key value pair ("raw/foo/bar/biz", 42.0) to path - ... # `dir/raw/foo/bar` at step 2000 - ... logger.record("biz", 42.0) - ... logger.dump(step=2000) - ... # record the key value pair `("mean/foo/bar/biz", 42.0)` to path - ... # `dir` at step 1. - ... logger.dump(step=1) - ... with open(os.path.join(dir, 'log.txt')) as f: - ... print(f.read()) - ------------------- - | loss | 1 | - ------------------- - --------------------------------- - | mean/ | | - | dataset/entropy | 5.5 | - --------------------------------- - ----------------------------- - | mean/ | | - | foo/bar/biz | 42 | - ----------------------------- - """ - default_logger: sb_logger.Logger - current_logger: Optional[sb_logger.Logger] - _cached_loggers: Dict[str, sb_logger.Logger] - _accumulate_prefixes: List[str] - _key_prefixes: List[str] - _subdir: Optional[str] - _name: Optional[str] - def __init__( self, default_logger: sb_logger.Logger, @@ -143,10 +93,7 @@ def __init__( self.default_logger = default_logger self.current_logger = None self._cached_loggers = {} - self._accumulate_prefixes = [] - self._key_prefixes = [] self._subdir = None - self._name = None self.format_strs = format_strs super().__init__(folder=self.default_logger.dir, output_formats=[]) @@ -156,95 +103,27 @@ def _update_name_to_maps(self) -> None: self.name_to_excluded = self._logger.name_to_excluded @contextlib.contextmanager - def add_accumulate_prefix(self, prefix: str) -> Generator[None, None, None]: - """Add a prefix to the subdirectory used to accumulate means. - - This prefix only applies when a `accumulate_means` context is active. If there - are multiple active prefixes, then they are concatenated. - - Args: - prefix: The prefix to add to the named sub. - - Yields: - None when the context manager is entered - - Raises: - RuntimeError: if accumulate means context is already active. - """ - if self.current_logger is not None: - raise RuntimeError( - "Cannot add prefix when accumulate_means context is already active.", - ) - - try: - self._accumulate_prefixes.append(prefix) - yield - finally: - self._accumulate_prefixes.pop() - - def get_accumulate_prefixes(self) -> str: - prefixes = "/".join(self._accumulate_prefixes) - return prefixes + "/" if prefixes else "" - - @contextlib.contextmanager - def add_key_prefix(self, prefix: str) -> Generator[None, None, None]: - """Add a prefix to the keys logged during an accumulate_means context. - - This prefix only applies when a `accumulate_means` context is active. - If there are multiple active prefixes, then they are concatenated. - - Args: - prefix: The prefix to add to the keys. - - Yields: - None when the context manager is entered - - Raises: - RuntimeError: if accumulate means context is already active. - """ - if self.current_logger is None: - raise RuntimeError( - "Cannot add key prefix when accumulate_means context is not active.", - ) + def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]: + """Temporarily modifies this HierarchicalLogger to accumulate means values. - try: - self._key_prefixes.append(prefix) - yield - finally: - self._key_prefixes.pop() + During this context, `self.record(key, value)` writes the "raw" values in + "{self.default_logger.log_dir}/{subdir}" under the key "raw/{subdir}/{key}". + At the same time, any call to `self.record` will also accumulate mean values + on the default logger by calling + `self.default_logger.record_mean(f"mean/{subdir}/{key}", value)`. - @contextlib.contextmanager - def accumulate_means(self, name: str) -> Generator[None, None, None]: - """Temporarily modifies this HierarchicalLogger to accumulate means values. + During the context, `self.record(key, value)` will write the "raw" values in + `"{self.default_logger.log_dir}/subdir"` under the key "raw/{subdir}/key". - Within this context manager, ``self.record(key, value)`` writes the "raw" values - in ``f"{self.default_logger.log_dir}/[{accumulate_prefix}/]{name}"`` under the - key ``"raw/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}"``, where - ``accumulate_prefix`` is the concatenation of all prefixes added by - ``add_accumulate_prefix`` and ``key_prefix`` is the concatenation of all - prefixes added by ``add_key_prefix``, if any. At the same time, any call to - ``self.record`` will also accumulate mean values on the default logger by - calling:: - - self.default_logger.record_mean( - f"mean/[{accumulate_prefix}/]{name}/[{key_prefix}/]{key}", - value, - ) - - Multiple prefixes may be active at once. In this case the `prefix` is simply the - concatenation of each of the active prefixes in the order they - were created e.g. if the active prefixes are ``['foo', 'bar']`` then - the prefix is ``'foo/bar'``. - - After the context exits, calling ``self.dump()`` will write the means + After the context exits, calling `self.dump()` will write the means of all the "raw" values accumulated during this context to - ``self.default_logger`` under keys of the form ``mean/{prefix}/{name}/{key}`` + `self.default_logger` under keys with the prefix `mean/{subdir}/` - Note that the behavior of other logging methods, ``log`` and ``record_mean`` + Note that the behavior of other logging methods, `log` and `record_mean` are unmodified and will go straight to the default logger. Args: - name: A string key which determines the ``folder`` where raw data is + subdir: A string key which determines the `folder` where raw data is written and temporary logging prefixes for raw and mean data. Entering an `accumulate_means` context in the future with the same `subdir` will safely append to logs written in this folder rather than @@ -260,11 +139,10 @@ def accumulate_means(self, name: str) -> Generator[None, None, None]: if self.current_logger is not None: raise RuntimeError("Nested `accumulate_means` context") - subdir = os.path.join(*self._accumulate_prefixes, name) - if subdir in self._cached_loggers: logger = self._cached_loggers[subdir] else: + subdir = types.path_to_str(subdir) folder = os.path.join(self.default_logger.dir, "raw", subdir) os.makedirs(folder, exist_ok=True) output_formats = _build_output_formats(folder, self.format_strs) @@ -274,38 +152,20 @@ def accumulate_means(self, name: str) -> Generator[None, None, None]: try: self.current_logger = logger self._subdir = subdir - self._name = name self._update_name_to_maps() yield finally: self.current_logger = None self._subdir = None - self._name = None self._update_name_to_maps() def record(self, key, val, exclude=None): if self.current_logger is not None: # In accumulate_means context. assert self._subdir is not None - raw_key = "/".join( - [ - "raw", - *self._accumulate_prefixes, - self._name, - *self._key_prefixes, - key, - ], - ) + raw_key = "/".join(["raw", self._subdir, key]) self.current_logger.record(raw_key, val, exclude) - mean_key = "/".join( - [ - "mean", - *self._accumulate_prefixes, - self._name, - *self._key_prefixes, - key, - ], - ) + mean_key = "/".join(["mean", self._subdir, key]) self.default_logger.record_mean(mean_key, val, exclude) else: # Not in accumulate_means context. self.default_logger.record(key, val, exclude) diff --git a/tests/util/test_logger.py b/tests/util/test_logger.py index 7f076df9c..4048c44d3 100644 --- a/tests/util/test_logger.py +++ b/tests/util/test_logger.py @@ -163,8 +163,8 @@ def test_name_to_value(tmpdir): def test_hard(tmpdir): hier_logger = logger.configure(tmpdir) - # Part One: Test logging outside the accumulating scope, and within scopes - # with two different logging keys (including a repeat). + # Part One: Test logging outside of the accumulating scope, and within scopes + # with two different different logging keys (including a repeat). hier_logger.record("no_context", 1) @@ -229,93 +229,3 @@ def test_hard(tmpdir): _compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default) _compare_csv_lines(osp.join(tmpdir, "raw", "gen", "progress.csv"), expect_raw_gen) _compare_csv_lines(osp.join(tmpdir, "raw", "disc", "progress.csv"), expect_raw_disc) - - -def test_accumulate_prefix(tmpdir): - hier_logger = logger.configure(tmpdir) - - with hier_logger.add_accumulate_prefix("foo"), hier_logger.accumulate_means("bar"): - hier_logger.record("A", 1) - hier_logger.record("B", 2) - hier_logger.dump() - - hier_logger.record("no_context", 1) - - with hier_logger.accumulate_means("blat"): - hier_logger.record("C", 3) - hier_logger.dump() - - hier_logger.dump() - - expect_raw_foo_bar = { - "raw/foo/bar/A": [1], - "raw/foo/bar/B": [2], - } - expect_raw_blat = { - "raw/blat/C": [3], - } - expect_default = { - "mean/foo/bar/A": [1], - "mean/foo/bar/B": [2], - "mean/blat/C": [3], - "no_context": [1], - } - - _compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default) - _compare_csv_lines( - osp.join(tmpdir, "raw", "foo", "bar", "progress.csv"), - expect_raw_foo_bar, - ) - _compare_csv_lines(osp.join(tmpdir, "raw", "blat", "progress.csv"), expect_raw_blat) - - -def test_key_prefix(tmpdir): - hier_logger = logger.configure(tmpdir) - - with hier_logger.accumulate_means("foo"), hier_logger.add_key_prefix("bar"): - hier_logger.record("A", 1) - hier_logger.record("B", 2) - hier_logger.dump() - - hier_logger.record("no_context", 1) - - with hier_logger.accumulate_means("blat"): - hier_logger.record("C", 3) - hier_logger.dump() - - hier_logger.dump() - - expect_raw_foo_bar = { - "raw/foo/bar/A": [1], - "raw/foo/bar/B": [2], - } - expect_raw_blat = { - "raw/blat/C": [3], - } - expect_default = { - "mean/foo/bar/A": [1], - "mean/foo/bar/B": [2], - "mean/blat/C": [3], - "no_context": [1], - } - - _compare_csv_lines(osp.join(tmpdir, "progress.csv"), expect_default) - _compare_csv_lines( - osp.join(tmpdir, "raw", "foo", "progress.csv"), - expect_raw_foo_bar, - ) - _compare_csv_lines(osp.join(tmpdir, "raw", "blat", "progress.csv"), expect_raw_blat) - - -def test_cant_add_prefix_within_accumulate_means(tmpdir): - h = logger.configure(tmpdir) - with pytest.raises(RuntimeError): - with h.accumulate_means("foo"), h.add_accumulate_prefix("bar"): - pass # pragma: no cover - - -def test_cant_add_key_prefix_outside_accumulate_means(tmpdir): - h = logger.configure(tmpdir) - with pytest.raises(RuntimeError): - with h.add_key_prefix("bar"): - pass # pragma: no cover