Implemented the ability to train rewards in preference comparison against multiple policies #529
Codecov Report
@@ Coverage Diff @@
## master #529 +/- ##
==========================================
+ Coverage 97.16% 97.19% +0.03%
==========================================
Files 85 85
Lines 7646 7769 +123
==========================================
+ Hits 7429 7551 +122
- Misses 217 218 +1
@levmckinney let me know if there are any particular areas you'd like my feedback on. I'll aim to do a high-level review but may not have time to look at it in detail. @Rocamonde can you do a style/technical review on Monday?
When reading through the logging code, I see we are maintaining a somewhat advanced and general infrastructure for hierarchical logging. I'm surprised there isn't an existing way to handle this without a bespoke implementation. What are other people doing (i.e. in similar ML projects)? Are our logging needs very specific and different from the way this is normally done? It might be interesting to spin off the logging module into e.g. an SB3 PR, so people don't have to reinvent the wheel every time.
Some initial comments. Have not looked at the tests yet. Let me know what you think and when/if you implement any changes from my suggestions, and I can take another look.
@@ -72,27 +74,44 @@ def _update_name_to_maps(self) -> None:
self.name_to_excluded = self._logger.name_to_excluded

@contextlib.contextmanager
def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]:
def add_prefix(self, prefix: str) -> Generator[None, None, None]:
If the prefix is removed after leaving the context, and only one prefix is supported as input, why are you adding support for a list of prefixes? How and why would I use multiple prefixes? (Is the expectation that I should enter nested prefix contexts? That might be quite hard to read and understand in practice, e.g. if this happens in different files or function calls.) I also think this prefix/name idea should be explained with an example in e.g. the class-level docstring.
If I'm understanding correctly how this is supposed to be used, entering this should be disallowed if one is using an `accumulate_means` context, otherwise that would mess with the path where the rest of the logs are being recorded. Do you agree and if so do you think you can add a way to throw an error?
Agreed, I've added a runtime error for this case.
I've added a doctest documenting how to use it.
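For readers following the thread, here is a minimal sketch of the behaviour discussed above: prefixes nest, each one is removed when its context exits, and entering a prefix context inside an accumulating context raises a runtime error. The `PrefixLogger` class, its `records` dict, and the error message are hypothetical stand-ins written for illustration; this is not the `HierarchicalLogger` implementation from the PR.

```python
import contextlib
from typing import Dict, Generator, List, Optional


class PrefixLogger:
    """Toy logger illustrating nested prefixes (hypothetical, not the PR's class)."""

    def __init__(self) -> None:
        self.records: Dict[str, float] = {}
        self._prefixes: List[str] = []
        self._subdir: Optional[str] = None

    @contextlib.contextmanager
    def add_prefix(self, prefix: str) -> Generator[None, None, None]:
        # Assumed rule from the review: changing the prefix while means are
        # being accumulated would move the log path, so refuse to do it.
        if self._subdir is not None:
            raise RuntimeError("Cannot add prefix inside accumulate_means context")
        self._prefixes.append(prefix)
        try:
            yield
        finally:
            # The prefix is removed again when the context exits.
            self._prefixes.pop()

    @contextlib.contextmanager
    def accumulate_means(self, subdir: str) -> Generator[None, None, None]:
        # Simplified: only remembers the subdirectory, no averaging is done.
        self._subdir = subdir
        try:
            yield
        finally:
            self._subdir = None

    def record(self, key: str, value: float) -> None:
        # Build the full key from the active prefixes, then the subdir, then the key.
        parts = list(self._prefixes)
        if self._subdir is not None:
            parts.append(self._subdir)
        parts.append(key)
        self.records["/".join(parts)] = value


logger = PrefixLogger()
with logger.add_prefix("agent_0"):
    with logger.add_prefix("rl"):  # nested contexts stack their prefixes
        logger.record("loss", 0.5)
with logger.add_prefix("agent_1"):
    logger.record("loss", 0.7)
logger.record("loss", 0.1)  # no prefix is active here
print(sorted(logger.records))  # ['agent_0/rl/loss', 'agent_1/loss', 'loss']
```

With a design along these lines, several trainers can write to the same logger without overwriting each other's keys, which is the use case the multi-agent setup needs.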
What we are doing is kind of weird -- you don't normally have several RL trainers that all need to log to the same place -- but it could come up in some other applications, e.g. population-based training. You could look at how |
@AdamGleave Could you look at the changes to the HierarchicalLogger? |
Co-authored-by: Adam Gleave <[email protected]> Co-authored-by: Juan Rocamonde <[email protected]>
The new changes LGTM. Do complete the docstring (see comment) before merging. Others might want to have a chance to take a look too before then.
Left some minor comments, but overall looks good. Only did a shallow review. Let me know if there's anything specific I should pay closer attention to.
@@ -133,6 +134,9 @@ def train_preference_comparisons(
be allocated to each iteration. "hyperbolic" and "inverse_quadratic"
apportion fewer queries to later iterations when the policy is assumed
to be better and more stable.
share_training_steps_among_agents: If True (default), when training with
Having the default behavior preserve the total training timesteps budget seems good to me. Integer division `//` vs something more clever that handles remainders seems unimportant -- we usually have hundreds of thousands of timesteps and single-digit numbers of agents.
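To make the remainder point concrete, a small sketch of the integer-division split follows. The helper name `split_timesteps` and its `share` flag are hypothetical, and the behaviour when `share` is False (each agent receives the full budget) is an assumption, since the docstring in the diff above is truncated.

```python
def split_timesteps(total_timesteps: int, num_agents: int, share: bool = True) -> int:
    """Return the per-agent timestep budget (hypothetical helper for illustration)."""
    if num_agents < 1:
        raise ValueError("num_agents must be at least 1")
    if share:
        # Integer division keeps the combined budget at or below the total.
        return total_timesteps // num_agents
    # Assumption: without sharing, every agent trains for the full budget.
    return total_timesteps


per_agent = split_timesteps(200_000, 3)
unused = 200_000 - per_agent * 3
print(per_agent, unused)  # 66666 2
```

The number of dropped timesteps is always smaller than the number of agents, so with budgets in the hundreds of thousands the loss is negligible.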
relabel_reward_fn=relabel_reward_fn,
)
else:
agent = rl_common.load_rl_algo_from_path(
There'll be hilariously little diversity between agents in this case, but not much we can do there. (Support loading different agents I guess? But that seems overkill for what's a rare use case.)
else:
single_agent = False
Not single agent in the sense that it's... zero agent? This is a bit counterintuitive.
@@ -251,7 +281,7 @@ def save_callback(iteration_num):
"checkpoints",
f"{iteration_num:04d}",
),
allow_save_policy=bool(trajectory_path is None),
allow_save_policy=single_agent,
We can't save checkpoints if there are multiple agents? That's a little sad.
* Skeleton of regularization techniques * Added loss and weight regularizer, update_param fn protocol * Fix type error * Added logging * Renamed reg to regularization, passed logger down * Fixes linting issues * Silly linting error only visible on CircleCI * Update src/imitation/regularization/__init__.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/regularization/__init__.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/regularization/__init__.py Co-authored-by: Adam Gleave <[email protected]> * Make regularizer initialization control flow more readable * Typing improvements * Multiple fixes re: types and assertions * Improved input validation * Add support for Lp norms, remove singleton L1/L2 implementations * Add support for arbitrary regularizers, improve input validation * Typing and linting * Fix silly errata causing test to fail * Add missing docstring args * Restructured folder into submodules and improved typing by adding generics * Updated imports after restructuring folder * Added tests for updaters.py * Added tests for regularizers.py, except weight decay * Fixed tests for lp regularization * Added tests for weight decay * Linting / formatting * Linting / typing * Final tests to improve code coverage * Tweaks for code coverage * Tweaks for code coverage v2 * Fix logging issues * Formatting * Formatting (docstring) * Fix file open issue * Linting * Remove useless conversion to float * Replace assert with ValueError * Check for lambda being negative in the scaler * Guard against losses being negative in interval param scaler * Split tests up to cover more cases * Clean up repetitive code for readability * Remove old TODO message * Merge RewardTrainer seeds into one * Fix interval param tests to new error messages * Move regularization input validation to factory class * Remake the regularizer factory for better API design * Fix bugs and tests in new factory design * Added docstring to regularizer factory. 
* Update src/imitation/regularization/regularizers.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/regularization/regularizers.py Co-authored-by: Adam Gleave <[email protected]> * Add todo to refactor once #529 is merged. * Rename regularize to regularize_and_backward * Fix bug in tests and docstrings * Rename mode to prefix * Added exceptions to docstrings * Make type ignore only specific to pytype * Add verbatim double-`` to some docstrings * Change phrasing in docstring Co-authored-by: Adam Gleave <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>
* Add support for prefix context manager in logger (from #529) * Added back accidentally removed code * Replaced preference comparisons prefix with ctx manager * Fixed errors * Docstring fixes * Address PR comments * Point SB3 to master to include bug fix * Format / fix tests for context manager * Switch to sb3 1.6.1 * Formatting * Remove comment
#534) * Initial mypy configuration * Fix types: test_envs.py * Fix types: conftest.py * Fix types: tests/util * Fix types: tests/scripts * Fix types: tests/rewards * Fix types: tests/policies * Incorrect decorator in update_stats method form networks.py::BaseNorm * Fix types: tests/algorithms (adersarial and bc) * Fix types: tests/algorithms (dagger and pc) * Fix types: tests/data * Linting * Linting * Fix types: algorithms/preference_comparisons.py * Fix types: algorithms/mce_irl.py * Formatting, fixed minor bug * Clarify why types are ignored * Started fixing types on algorithms/density.py * Linting * Linting (add back type ignore after reformatting) * Fixed types: imitation/data/types.py * Fixed types (started): imitation/data/ * Fixed types: imitation/data/buffer.py * Fixed bug in buffer.py * Fixed types: imitation/data/rollout.py * Fixed types: imitation/data/wrappers.py * Improve makefile to support automatic cache cleaning * Fixed types: imitation/testing/ * Linting, fixed wrong return type in rewards.predict_processed_all * Fixed types: imitation/policies/ * Formatting * Fixed types: imitation/rewards/ * Fixed types: imitation/rewards/ * Fixed types: imitation/scripts/ * Fixed types: imitation/util/ and formatting * Linting and formatting * Bug fixes for test errors * Linting and typing * Improve typing in algorithms * Formatting * Bug fix * Formatting * Fixes suggested by Adam. * Fix mypy version. * Update TabularPolicy.predict to match base class * Fix not checking for dones * Change for loop to dict comprehension * Remove is_ensemble to clear up type checking errors * Reduce code duplication and general cleanup * Fix type annotation of step_dict * Change List to Sequence * Fix density.py::DensityAlgorithm._set_demo_from_batch * Fixed n_steps (OnPolicyAlgorithm) * Fix errors in tests * Include some suggestions into rollout.py and preference_comparisons.py * Formatting * Fix setter error as per python/mypy#5936 * add reason for assertion. 
* Fix style guide violation: https://google.github.io/styleguide/pyguide.html#22-imports * Update src/imitation/scripts/parallel.py Co-authored-by: Adam Gleave <[email protected]> * Move kwargs to the end. * Swap order of expert_policy_type and expert_policy_path validation check * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Update tests/rewards/test_reward_fn.py Co-authored-by: Adam Gleave <[email protected]> * Explicit random state setting and fix corresponding tests (except notebooks, sacred config, scripts) * Fix notebooks; add script to clean notebooks * Fix all tests. * Formattting. * Additional fixes * Linting * Remove automatically generated `_api` docs files too on `make clean` * Fix docstrings. * Fix issue with next(iter(iterable)) * Formatting * Remove whitespace * Add TODO message to remove type ignore later * Remove unnecessary assertion. * Fixed types in density.py set_demonstrations * Added type ignore to pytype bug * Fix_get_first_iter_element and add tests * Bugfix in BC and tests -- masked as previously iterator ran out too early! * Remove makefile for now * Added link to SB3 issue for future reference. * Fix types of train_imitation Only return "expert_stats" if all trajectories have reward. * Modify assert in test_bc to reflect correct type * Add ci/clean_notebooks.py to CI checks * Improve clean_notebooks.py by allowing checking only mode. 
* Add ipynb notebook checks to CI * Add support for explicit files for notebook cleaning * Clean notebooks * Small improvements in util.py * Replace TransitionKind with TransitionsMinimal * Delete unused statement in test * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Make type ignore specific to pytype * Linting * Migrate from RandomState (deprecated) to Generator * Add backticks to error message * Create "AnyNorm" alias * Small fix * Add additional checks to shapes in _set_demo_from_batch * Fix RolloutStatsComputer type * Improved logging/messages in clean_notebooks.py * Fix issues resulting from merge * Bug fix * Bug fix (wasn't really fixed before) * Fixed docs example of BC * Fix bugs resulting from merge * Fix docs (dagger.rst) caught by sphinx CI * Add mypy to CI * Continue fixing miscellaneous type errors * Linting * Fix issue with normalize_input_layer type * Add support for checking presence of generic type ignores * Allow subdirectories in notebook clean * Add full typing support for TransitionsMinimal as a sequence * Fix types for density.py * Misc fixes * Add support for prefix context manager in logger (from #529) * Added back accidentally removed code * Replaced preference comparisons prefix with ctx manager * Fixed errors * Bug fixes * Docstring fixes * Fix bug in serialize.py * Fixed codecheck by pointing notebook checks to docs * Add rng to mce_irl.rst (doctest) * Add rng to density.rst (doctest) * Fix remaining rst files * Increase sample size to reduce flakiness * Ignore files not passing mypy for now * Comment in wrong line * Comment in wrong line * Move excluded files to argument * Add quotes to mypy arg call * Fix CI mypy call * Fix CI yaml * Break ignored files up into one line each * Address PR comments * Point SB3 to master to include bug fix * Do not follow imports for ignored files * Format / fix tests for context 
manager * Switch to sb3 1.6.1 * Formatting * Remove unused import * Remove unused fixture * Add coveragerc file * Add utils test * Add tests and asserts * Add test to synthetic gatherer * Add trajectory unwrap tests * Formatting * Remove bracket typo * Fix .coveragerc instruction * Improve density algo coverage and bug fixes * Fix bug in test * Add pragma no cover updates * Minor coverage tweaks * Fix iterator test * Update ci/check_typeignore.py Co-authored-by: Adam Gleave <[email protected]> * DRY clean_notebooks.py * Minor tweak in check_typeignore.py * Added all checks to CI * Move imports to top-level * Move main to main() function in script * Minor fixes * Remove tweak after new SB3 release * Split main() into helper functions. * Fix edge case of n=0 in seed generator * Update src/imitation/scripts/common/rl.py Co-authored-by: Adam Gleave <[email protected]> * Fix general type ignore in src * Fix type ignore errors * Formatting * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/scripts/common/rl.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Replace todo with todo+issue link * Add explicit type ignore arg * Add excluded files to code_checks.sh * Unbreak line * Misc fixes * Add training to the density algorithm * Fix no attribute error * Type ignore to `with raises` test * Remove unused import * Check typeignore for all SRC files * Clean notebooks * Remove unused import * Ignore the file itself from typeignore check * Add exception to docstring * Fix bad naming for clean_notebooks Co-authored-by: Adam Gleave <[email protected]>
* Initial mypy configuration * Initial change to get the PR up * Initial review at replacing os.path * Bug fixes from tests * Fix types: test_envs.py * Fix types: conftest.py * Fix types: tests/util * Fix types: tests/scripts * Fix types: tests/rewards * Fix types: tests/policies * Incorrect decorator in update_stats method form networks.py::BaseNorm * Fix types: tests/algorithms (adersarial and bc) * Fix types: tests/algorithms (dagger and pc) * Fix types: tests/data * Linting * Linting * Fix types: algorithms/preference_comparisons.py * Fix types: algorithms/mce_irl.py * Formatting, fixed minor bug * Clarify why types are ignored * Started fixing types on algorithms/density.py * Linting * Linting (add back type ignore after reformatting) * Fixed types: imitation/data/types.py * Fixed types (started): imitation/data/ * Fixed types: imitation/data/buffer.py * Fixed bug in buffer.py * Fixed types: imitation/data/rollout.py * Fixed types: imitation/data/wrappers.py * Improve makefile to support automatic cache cleaning * Fixed types: imitation/testing/ * Linting, fixed wrong return type in rewards.predict_processed_all * Fixed types: imitation/policies/ * Formatting * Fixed types: imitation/rewards/ * Fixed types: imitation/rewards/ * Fixed types: imitation/scripts/ * Fixed types: imitation/util/ and formatting * Linting and formatting * Bug fixes for test errors * Linting and typing * Improve typing in algorithms * Formatting * Bug fix * Formatting * Fixes suggested by Adam. * Fix mypy version. 
* Fix bugs * Remove unused imports * Formatting * Added parse_path func and refactored code to use it * Fix typing, linting * Update TabularPolicy.predict to match base class * Fix not checking for dones * Change for loop to dict comprehension * Remove is_ensemble to clear up type checking errors * Reduce code duplication and general cleanup * Fix type annotation of step_dict * Change List to Sequence * Fix density.py::DensityAlgorithm._set_demo_from_batch * Fixed n_steps (OnPolicyAlgorithm) * Fix errors in tests * Include some suggestions into rollout.py and preference_comparisons.py * Formatting * Fix setter error as per python/mypy#5936 * add reason for assertion. * Fix style guide violation: https://google.github.io/styleguide/pyguide.html#22-imports * Update src/imitation/scripts/parallel.py Co-authored-by: Adam Gleave <[email protected]> * Move kwargs to the end. * Swap order of expert_policy_type and expert_policy_path validation check * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Update tests/rewards/test_reward_fn.py Co-authored-by: Adam Gleave <[email protected]> * Explicit random state setting and fix corresponding tests (except notebooks, sacred config, scripts) * Fix notebooks; add script to clean notebooks * Fix all tests. * Formattting. * Additional fixes * Linting * Remove automatically generated `_api` docs files too on `make clean` * Fix docstrings. * Fix issue with next(iter(iterable)) * Formatting * Remove whitespace * Add TODO message to remove type ignore later * Remove unnecessary assertion. * Fixed types in density.py set_demonstrations * Added type ignore to pytype bug * Fix_get_first_iter_element and add tests * Bugfix in BC and tests -- masked as previously iterator ran out too early! * Remove makefile for now * Added link to SB3 issue for future reference. * Fix types of train_imitation Only return "expert_stats" if all trajectories have reward. 
* Modify assert in test_bc to reflect correct type * Add ci/clean_notebooks.py to CI checks * Improve clean_notebooks.py by allowing checking only mode. * Add ipynb notebook checks to CI * Add support for explicit files for notebook cleaning * Clean notebooks * Small improvements in util.py * Replace TransitionKind with TransitionsMinimal * Delete unused statement in test * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/util/util.py Co-authored-by: Adam Gleave <[email protected]> * Make type ignore specific to pytype * Linting * Migrate from RandomState (deprecated) to Generator * Add backticks to error message * Create "AnyNorm" alias * Small fix * Add additional checks to shapes in _set_demo_from_batch * Fix RolloutStatsComputer type * Improved logging/messages in clean_notebooks.py * Fix issues resulting from merge * Bug fix * Bug fix (wasn't really fixed before) * Fixed docs example of BC * Fix bugs resulting from merge * Fix docs (dagger.rst) caught by sphinx CI * Add mypy to CI * Continue fixing miscellaneous type errors * Linting * Fix issue with normalize_input_layer type * Add support for checking presence of generic type ignores * Allow subdirectories in notebook clean * Add full typing support for TransitionsMinimal as a sequence * Fix types for density.py * Misc fixes * Add support for prefix context manager in logger (from #529) * Added back accidentally removed code * Replaced preference comparisons prefix with ctx manager * Fixed errors * Bug fixes * Docstring fixes * Fix bug in serialize.py * Fixed codecheck by pointing notebook checks to docs * Add rng to mce_irl.rst (doctest) * Add rng to density.rst (doctest) * Fix remaining rst files * Increase sample size to reduce flakiness * Ignore files not passing mypy for now * Comment in wrong line * Comment in wrong line * Move excluded files to argument * Add quotes to mypy arg call * Fix CI mypy call * Fix CI yaml * Break ignored files up into 
one line each * Address PR comments * Point SB3 to master to include bug fix * Small bug fixes * Small bug fixes * Sort import * Linting * Do not follow imports for ignored files * Fix tests for context managers * Format / fix tests for context manager * Switch to sb3 1.6.1 * Formatting * Upgrade Python version in Windows CI * Remove unused import * Remove unused fixture * Add coveragerc file * Add utils test * Add tests and asserts * Add test to synthetic gatherer * Add trajectory unwrap tests * Formatting * Remove bracket typo * Fix .coveragerc instruction * Improve density algo coverage and bug fixes * Fix bug in test * Add pragma no cover updates * Minor coverage tweaks * Fix iterator test * Add test for parse_path * Updates on sacred util * Mark type ignore rule * Mark type ignore rule * Miscellaneous bug fixes and improvements * Reformat hanging line * Ignore parse path checks for windows * Add trailing comma * Minor changes * No newline end of file * Update src/imitation/data/types.py Co-authored-by: Adam Gleave <[email protected]> * Update src/imitation/data/types.py Co-authored-by: Adam Gleave <[email protected]> * Include suggestions from Adam Co-authored-by: Adam Gleave <[email protected]>
What's the status of this? I borrowed from this PR partially in other PRs that have been merged to avoid blockage / reinventing the wheel, so it probably requires some re-writing. |
@levmckinney, what's outstanding here? I know you're busy now, but I can likely find someone else to finish the PR off. However, IIRC the policy ensemble didn't help performance much, so perhaps it's served its purpose (of running those experiments) and we should just cherry-pick any helpful features from it then close this PR. |
I think we can close this. The main thing we might want to cherry-pick out is the changes to the hierarchical logger. However, that only really matters if we are planning to move forward with this logging framework. |
Description
This pull request adds the `MixtureOfTrajectoryGenerators` class, which allows preference comparison to be trained on data from multiple sources. This includes the ability to train against multiple agents and to train against a mixture of data from an agent and a dataset. To facilitate this, I have modified the `HierarchicalLogger` class to allow arbitrary prefixes to be added, so that the logs of different trajectory generators don't clobber each other.
Testing
Changes to the `HierarchicalLogger` class have been unit tested. `MixtureOfTrajectoryGenerators` has also been unit tested. I've also included new integration tests for the `train_preference_comparison.py` script to test the use case of having multiple policies.
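As a rough illustration of the mixture idea, the sketch below draws a requested number of trajectories from several sources by assigning each one to a source uniformly at random. Everything here is hypothetical: `MixtureSketch`, `make_source`, the string "trajectories", and the uniform assignment rule are assumptions made for the example, and the PR's `MixtureOfTrajectoryGenerators` may split samples differently.

```python
import random
from typing import Callable, List, Sequence

Trajectory = str  # stand-in type; real trajectories would be richer objects
SampleFn = Callable[[int], List[Trajectory]]


def make_source(name: str) -> SampleFn:
    """Build a fake trajectory source that labels its output with `name`."""

    def sample(n: int) -> List[Trajectory]:
        return [f"{name}-traj-{i}" for i in range(n)]

    return sample


class MixtureSketch:
    """Draws trajectories from several sources (hypothetical illustration)."""

    def __init__(self, members: Sequence[SampleFn], seed: int = 0) -> None:
        if not members:
            raise ValueError("at least one member is required")
        self.members = list(members)
        self.rng = random.Random(seed)

    def sample(self, n: int) -> List[Trajectory]:
        # Assumed rule: each requested trajectory is assigned to a member
        # uniformly at random, then each member is asked for its share.
        counts = [0] * len(self.members)
        for _ in range(n):
            counts[self.rng.randrange(len(self.members))] += 1
        trajectories: List[Trajectory] = []
        for member, count in zip(self.members, counts):
            if count > 0:
                trajectories.extend(member(count))
        # Shuffle so downstream code does not see the sources in blocks.
        self.rng.shuffle(trajectories)
        return trajectories


mixture = MixtureSketch([make_source("agent_0"), make_source("agent_1"), make_source("dataset")])
batch = mixture.sample(10)
print(len(batch))  # 10
```

A source here can be an agent or a fixed dataset, since the mixture only relies on the shared sampling interface.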