# Major modeling refactoring #165 (merged)
**Commits** (33):

- 81074f0 (fedebotu): [Feat] add entropy calculation
- bb64ef1 (fedebotu): [Feat] action logprob evaluation
- 44c4901 (fedebotu): [Minor] remove unused_kwarg for clarity
- fbd4941 (fedebotu): [Rename] embedding_dim -> embed_dim (PyTorch naming convention)
- 6e07985 (fedebotu): [Move] move common one level up
- f30c32d (fedebotu): [Refactor] classify NCO as constructive (AR, NAR), improvement, search
- 3a16b7c (fedebotu): [Refactor] follow major refactoring
- 3ec285e (fedebotu): [Refactor] cleaner implementation; eval via policy itself
- 796d54a (fedebotu): [Refactor] make env_name an optional kwarg
- faab06e (fedebotu): [Tests] adapt to refactoring
- 5d04dfa (fedebotu): [Refactor] new structure; env_name as optional; embed_dim standardiza…
- 4e6351c (fedebotu): [Tests] minor fix
- 10cc4ee (ahottung): Fixing best solution gathering for POMO
- 81a3bf9 (ahottung): Fixing bug introduced in last commit
- 7034172 (fedebotu): Merge remote-tracking branch 'origin/main' into refactor-base
- 3644acb (fedebotu): [BugFix] default POMO parameters
- cd62442 (fedebotu): [Rename] Search -> Transductive
- 4180997 (fedebotu): [Feat] add NARGNN (as in DeepACO) as a separate policy and encoder
- e783679 (fedebotu): [Refactor] abstract classes with abc.ABCMeta
- 5a4740f (fedebotu): [Refactor] abstract classes with abc.ABCMeta
- 3adbef4 (fedebotu): [Feat] modular Critic network
- db06207 (fedebotu): [Rename] PPOModel -> AMPPO
- 9ef3254 (fedebotu): [Refactor] separate A2C from classic REINFORCE #93
- ca44680 (fedebotu): [Minor] force env_name as str for clarity
- 2c91457 (fedebotu): [Tests] avoid testing render
- 6da8691 (fedebotu): [Doc] add docstrings
- 04ed94a (fedebotu): [BugFix] env_name not passed to base class
- b7fe9b3 (fedebotu): [Doc] update to latest version
- 3558d57 (fedebotu): [Minor] woopsie, remove added exampels
- c3089fb (fedebotu): [Minor] fix NAR; raise log error if any param is found in decoder
- c1e19e8 (fedebotu): [Doc] fix docstrings
- 90956af (fedebotu): [Doc] documentation update and improvements
- cfaf43d (fedebotu): Merge remote-tracking branch 'origin/main' into refactor-base
New file (@@ -0,0 +1,9 @@):

````
# A2C

## A2C (Advantage Actor Critic)

```{eval-rst}
.. automodule:: rl4co.models.rl.a2c.a2c
   :members:
   :undoc-members:
```
````
Changed file (@@ -6,4 +6,4 @@):

````
.. automodule:: rl4co.models.rl.common.base
   :members:
   :undoc-members:
```
````
Changed file (@@ -1,5 +1,9 @@):

````
# PPO

## PPO (Proximal Policy Optimization)

```{eval-rst}
.. automodule:: rl4co.models.rl.ppo.ppo
   :members:
````
Changed file (@@ -8,8 +8,6 @@):

````
   :undoc-members:
```

## Baselines

```{eval-rst}
````
This file was deleted.
New file (@@ -0,0 +1,7 @@):

````
# Decoding Strategies

```{eval-rst}
.. automodule:: rl4co.utils.decoding
   :members:
   :undoc-members:
```
````
This file was deleted.
New file (@@ -0,0 +1,53 @@):

# NCO Methods Overview

We categorize NCO approaches (which are, in fact, not necessarily trained with RL!) into the following: 1) constructive, 2) improvement, and 3) transductive.

```{eval-rst}
.. tip::
   Note that in RL4CO we distinguish the RL algorithms and the actors via the following naming:

   * **Model:** Refers to the reinforcement learning algorithm encapsulated within a `LightningModule`. This module is responsible for training the policy.
   * **Policy:** Implemented as a `nn.Module`, this neural network (often referred to as the *actor*) takes an instance and outputs a sequence of actions, :math:`\pi = \pi_0, \pi_1, \dots, \pi_N`, which constitutes the solution.

   Here, :math:`\pi_i` represents the action taken at step :math:`i`, forming a sequence that leads to the optimal or near-optimal solution for the given instance.
```
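As a plain-Python illustration of this split (a minimal sketch with made-up classes and a dummy reward, not the actual rl4co API), the policy maps an instance to actions while the model owns the training step:

```python
import random

class Policy:
    """Actor: takes an instance, returns a sequence of actions (the solution)."""
    def __call__(self, instance):
        # Toy stand-in for a neural network: visit nodes in a random order.
        actions = list(range(len(instance)))
        random.shuffle(actions)
        return actions

class Model:
    """RL algorithm wrapping a Policy (rl4co encapsulates this in a LightningModule)."""
    def __init__(self, policy):
        self.policy = policy

    def training_step(self, instance):
        actions = self.policy(instance)
        # A real model would compute the tour cost and a policy-gradient loss;
        # here the "reward" is just a dummy value based on the action count.
        return -float(len(actions))

policy = Policy()
model = Model(policy)
instance = [(0.1, 0.2), (0.5, 0.5), (0.9, 0.1)]  # three 2D "node" locations
print(sorted(policy(instance)))        # the actions form a permutation: [0, 1, 2]
print(model.training_step(instance))   # -3.0
```

The point of the separation is that the same policy (actor) can be trained by different models (REINFORCE, A2C, PPO, ...) without changes.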
The following table contains the categorization that we follow in RL4CO:

```{eval-rst}
.. list-table:: Overview of RL Models and Policies
   :widths: 5 5 5 5 25
   :header-rows: 1
   :stub-columns: 1

   * - Category
     - Model or Policy?
     - Input
     - Output
     - Description
   * - Constructive
     - Policy
     - Instance
     - Solution
     - Policies trained to generate solutions from scratch. Can be categorized into autoregressive (AR) and non-autoregressive (NAR).
   * - Improvement
     - Policy
     - Instance, Current Solution
     - Improved Solution
     - Policies trained to improve existing solutions iteratively, akin to local search algorithms. They focus on refining *existing* solutions rather than generating them from scratch.
   * - Transductive
     - Model
     - Instance, (Policy)
     - Solution, (Updated Policy)
     - Updates policy parameters during online testing to improve the solutions of a specific instance.
```
New file (@@ -0,0 +1,71 @@):

## Constructive Policies

Constructive NCO policies pre-train a policy to amortize inference. "Constructive" means that the model creates a solution from scratch. We can further divide constructive NCO into two sub-categories depending on the roles of the encoder and decoder:

#### Autoregressive (AR)
Autoregressive approaches **use a learned decoder** that outputs log-probabilities for the current partial solution. These approaches generate a solution step by step, similarly to, e.g., LLMs, and typically have an encoder-decoder structure. Some models may have no encoder at all and simply re-encode at each step.
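To make the step-by-step construction concrete, here is a framework-free sketch; the `score` function is an invented stand-in for a learned decoder's logits, not rl4co code:

```python
import math

def score(node, partial):
    # Made-up stand-in for the learned decoder's logit for `node`,
    # conditioned on the current partial solution.
    last = partial[-1] if partial else 0
    return -abs(node - last)

def ar_construct(num_nodes):
    """Build a solution one action at a time, masking visited nodes."""
    solution, remaining = [], set(range(num_nodes))
    while remaining:
        logits = {n: score(n, solution) for n in remaining}
        # Softmax over the unmasked logits gives the per-step distribution;
        # here we decode greedily (argmax) instead of sampling.
        z = max(logits.values())
        probs = {n: math.exp(v - z) for n, v in logits.items()}
        total = sum(probs.values())
        probs = {n: p / total for n, p in probs.items()}
        nxt = max(probs, key=probs.get)
        solution.append(nxt)
        remaining.remove(nxt)
    return solution

print(ar_construct(4))  # -> [0, 1, 2, 3] with this toy scorer
```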
#### Non-Autoregressive (NAR)
In NAR approaches, by contrast, **only the encoder is learnable**: it encodes the instance in one shot and generates, for example, a heatmap, which can then be decoded simply by treating it as a probability distribution or by running a search method on top.
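As a sketch of the one-shot idea (the heatmap values below are invented; rl4co's actual NAR encoders, e.g. NARGNN, are neural networks), a fixed heatmap of edge scores can be decoded greedily:

```python
# heatmap[i][j] ~ the encoder's learned preference for edge i -> j (assumed values)
heatmap = [
    [0.0, 0.9, 0.1, 0.0],
    [0.0, 0.0, 0.8, 0.2],
    [0.1, 0.0, 0.0, 0.9],
    [0.9, 0.1, 0.0, 0.0],
]

def decode_greedy(heatmap, start=0):
    """Cheap, non-learned decoder: repeatedly follow the highest-scoring
    edge to an unvisited node."""
    n = len(heatmap)
    tour, visited = [start], {start}
    while len(tour) < n:
        cur = tour[-1]
        nxt = max((j for j in range(n) if j not in visited),
                  key=lambda j: heatmap[cur][j])
        tour.append(nxt)
        visited.add(nxt)
    return tour

print(decode_greedy(heatmap))  # -> [0, 1, 2, 3]
```

Sampling from the (normalized) rows instead of taking the argmax, or running beam search or ant-colony search on the heatmap, are the usual alternatives.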
Here is the general structure of a constructive policy with an encoder-decoder architecture:

<img class="full-img" alt="policy" src="https://user-images.githubusercontent.com/48984123/281976545-ca88f159-d0b3-459e-8fd9-89799be9d1b0.png">

where _embeddings_ transfer information from feature space to embedding space.

---

### Constructive Policy Base Classes

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.base
   :members:
   :undoc-members:
```

### Autoregressive Policies Base Classes

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.autoregressive.encoder
   :members:
   :undoc-members:
```

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.autoregressive.decoder
   :members:
   :undoc-members:
```

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.autoregressive.policy
   :members:
   :undoc-members:
```

### Non-Autoregressive Policies Base Classes

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.nonautoregressive.encoder
   :members:
   :undoc-members:
```

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.nonautoregressive.decoder
   :members:
   :undoc-members:
```

```{eval-rst}
.. automodule:: rl4co.models.common.constructive.nonautoregressive.policy
   :members:
   :undoc-members:
```
New file (@@ -0,0 +1,3 @@):

## Improvement Policies

These methods differ from constructive NCO in that they can improve solutions over time, similarly to how local search algorithms work. This is also different from decoding strategies and similar techniques in constructive methods, since these policies are explicitly trained to perform improvement operations.
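For intuition, here is a sketch of a single improvement step over a 2-opt neighborhood; in an actual improvement policy, a trained network, rather than the exhaustive scan below, would select the move (illustrative code, not an rl4co class):

```python
def tour_length(tour, dist):
    """Cyclic tour length under a distance matrix."""
    return sum(dist[tour[i]][tour[(i + 1) % len(tour)]] for i in range(len(tour)))

def improve_once(tour, dist):
    """One improvement step: take the current solution as *input* and return
    a refined one, here via the best 2-opt segment reversal."""
    best, best_len = tour, tour_length(tour, dist)
    for i in range(1, len(tour) - 1):
        for j in range(i + 1, len(tour)):
            cand = tour[:i] + tour[i:j][::-1] + tour[j:]
            cand_len = tour_length(cand, dist)
            if cand_len < best_len:
                best, best_len = cand, cand_len
    return best  # an improvement policy would pick this move with a network

dist = [[0, 1, 4, 1], [1, 0, 1, 4], [4, 1, 0, 1], [1, 4, 1, 0]]
bad = [0, 2, 1, 3]                 # length 10 under `dist`
print(improve_once(bad, dist))     # -> [0, 1, 2, 3], length 4
```

The key contrast with constructive policies is the interface: the policy consumes (instance, current solution) and emits an improved solution, and it can be applied iteratively.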
New file (@@ -0,0 +1,19 @@):

# Transductive Models

Transductive models are learning algorithms that optimize on a specific instance. They improve solutions by updating the policy parameters $\theta$, which means that we run optimization (backprop) **at test time**. Transductive learning can be performed with different policies: for example, EAS updates (a part of) the parameters of an AR policy to obtain better solutions; other test-time optimization schemes are possible as well.
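A toy sketch of the test-time optimization loop follows; the cost function, names, and update rule are invented for illustration (EAS itself updates parameters of a neural policy via policy gradients, not this scalar gradient descent):

```python
def solve(instance, theta):
    # Toy "policy cost": best when theta matches this instance's offset.
    return (theta - instance["offset"]) ** 2

def transductive_search(instance, theta=0.0, lr=0.25, steps=40):
    """Tune the per-instance parameter theta at TEST time, i.e. we keep
    optimizing for this one instance rather than generalizing."""
    for _ in range(steps):
        grad = 2 * (theta - instance["offset"])  # gradient of the toy cost
        theta -= lr * grad                        # backprop-style update
    return theta, solve(instance, theta)

theta, cost = transductive_search({"offset": 3.0})
print(round(theta, 3), round(cost, 6))  # -> 3.0 0.0
```

The defining feature is that the updated parameters are specific to the test instance and are typically discarded afterwards.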
```{eval-rst}
.. tip::
   You may refer to the definition of `inductive vs transductive RL <https://en.wikipedia.org/wiki/Transduction_(machine_learning)>`_. In inductive RL, we train to generalize to new instances. In transductive RL, we train (or fine-tune) to solve only specific ones.
```

## Base Transductive Model

```{eval-rst}
.. automodule:: rl4co.models.common.transductive.base
   :members:
   :undoc-members:
```
This file was deleted.
Review comment: We could mention here, or somewhere else, that the abstract classes under `rl4co/models/common` are not expected to be initialized directly. For example, if you want to use an autoregressive policy, you may want to instantiate an AM model instead of `AutoregressivePolicy()`; the same goes for the NAR, improvement, and transductive classes.