
Major modeling refactoring #165

Merged: 33 commits, merged Apr 30, 2024
Changes from 14 commits
81074f0
[Feat] add entropy calculation
fedebotu Apr 23, 2024
bb64ef1
[Feat] action logprob evaluation
fedebotu Apr 23, 2024
44c4901
[Minor] remove unused_kwarg for clarity
fedebotu Apr 23, 2024
fbd4941
[Rename] embedding_dim -> embed_dim (PyTorch naming convention)
fedebotu Apr 23, 2024
6e07985
[Move] move common one level up
fedebotu Apr 23, 2024
f30c32d
[Refactor] classify NCO as constructive (AR,NAR), improvement, search
fedebotu Apr 23, 2024
3a16b7c
[Refactor] follow major refactoring
fedebotu Apr 23, 2024
3ec285e
[Refactor] cleaner implementation; eval via policy itself
fedebotu Apr 23, 2024
796d54a
[Refactor] make env_name an optional kwarg
fedebotu Apr 23, 2024
faab06e
[Tests] adapt to refactoring
fedebotu Apr 23, 2024
5d04dfa
[Refactor] new structure; env_name as optional; embed_dim standardiza…
fedebotu Apr 23, 2024
4e6351c
[Tests] minor fix
fedebotu Apr 23, 2024
10cc4ee
Fixing best solution gathering for POMO
ahottung Apr 24, 2024
81a3bf9
Fixing bug introduced in last commit
ahottung Apr 25, 2024
7034172
Merge remote-tracking branch 'origin/main' into refactor-base
fedebotu Apr 27, 2024
3644acb
[BugFix] default POMO parameters
fedebotu Apr 27, 2024
cd62442
[Rename] Search -> Transductive
fedebotu Apr 27, 2024
4180997
[Feat] add NARGNN (as in DeepACO) as a separate policy and encoder
fedebotu Apr 27, 2024
e783679
[Refactor] abstract classes with abc.ABCMeta
fedebotu Apr 27, 2024
5a4740f
[Refactor] abstract classes with abc.ABCMeta
fedebotu Apr 27, 2024
3adbef4
[Feat] modular Critic network
fedebotu Apr 28, 2024
db06207
[Rename] PPOModel -> AMPPO
fedebotu Apr 28, 2024
9ef3254
[Refactor] separate A2C from classic REINFORCE #93
fedebotu Apr 28, 2024
ca44680
Merge remote-tracking branch 'origin/main' into refactor-base
fedebotu Apr 28, 2024
2c91457
[Minor] force env_name as str for clarity
fedebotu Apr 28, 2024
6da8691
[Tests] avoid testing render
fedebotu Apr 28, 2024
04ed94a
[Doc] add docstrings
fedebotu Apr 28, 2024
b7fe9b3
[BugFix] env_name not passed to base class
fedebotu Apr 28, 2024
3558d57
[Doc] update to latest version
fedebotu Apr 28, 2024
c3089fb
[Minor] woopsie, remove added exampels
fedebotu Apr 28, 2024
c1e19e8
[Minor] fix NAR; raise log error if any param is found in decoder
fedebotu Apr 28, 2024
90956af
[Doc] fix docstrings
fedebotu Apr 28, 2024
cfaf43d
[Doc] documentation update and improvements
fedebotu Apr 28, 2024
2 changes: 1 addition & 1 deletion docs/_content/api/algos/search.md
@@ -1,7 +1,7 @@
# Search

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.search.base
.. automodule:: rl4co.models.common.search.base
:members:
:undoc-members:
```
12 changes: 6 additions & 6 deletions docs/_content/api/models/base.md
@@ -10,23 +10,23 @@ Autoregressive models are models that generate sequences one element at a time,
### Policy

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.autoregressive.policy
.. automodule:: rl4co.models.common.constructive.autoregressive.policy
:members:
:undoc-members:
```

### Encoder

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.autoregressive.encoder
.. automodule:: rl4co.models.common.constructive.autoregressive.encoder
:members:
:undoc-members:
```

### Decoder

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.autoregressive.decoder
.. automodule:: rl4co.models.common.constructive.autoregressive.decoder
:members:
:undoc-members:
```
@@ -38,15 +38,15 @@ Non-autoregressive models generate a heatmap of probabilities from one node to a
### Policy

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.nonautoregressive.policy
.. automodule:: rl4co.models.common.nonautoregressive.policy
:members:
:undoc-members:
```

### Encoder

```{eval-rst}
.. automodule:: rl4co.models.zoo.common.nonautoregressive.encoder
.. automodule:: rl4co.models.common.nonautoregressive.encoder
:members:
:undoc-members:
```
@@ -57,7 +57,7 @@ Note that we still need a decoding class for the heatmap (for example, to mask o


```{eval-rst}
.. automodule:: rl4co.models.zoo.common.nonautoregressive.decoder
.. automodule:: rl4co.models.common.nonautoregressive.decoder
:members:
:undoc-members:
```
37 changes: 14 additions & 23 deletions examples/1-quickstart.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions examples/3-creating-new-env-model.ipynb
@@ -507,10 +507,10 @@
" - locs: x, y coordinates of the cities\n",
" \"\"\"\n",
"\n",
" def __init__(self, embedding_dim, linear_bias=True):\n",
" def __init__(self, embed_dim, linear_bias=True):\n",
" super(TSPInitEmbedding, self).__init__()\n",
" node_dim = 2 # x, y\n",
" self.init_embed = nn.Linear(node_dim, embedding_dim, linear_bias)\n",
" self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)\n",
"\n",
" def forward(self, td):\n",
" out = self.init_embed(td[\"locs\"])\n",
@@ -539,13 +539,13 @@
" - current node embedding\n",
" \"\"\"\n",
"\n",
" def __init__(self, embedding_dim, linear_bias=True):\n",
" def __init__(self, embed_dim, linear_bias=True):\n",
" super(TSPContext, self).__init__()\n",
" self.W_placeholder = nn.Parameter(\n",
" torch.Tensor(2 * embedding_dim).uniform_(-1, 1)\n",
" torch.Tensor(2 * embed_dim).uniform_(-1, 1)\n",
" )\n",
" self.project_context = nn.Linear(\n",
" embedding_dim*2, embedding_dim, bias=linear_bias\n",
" embed_dim*2, embed_dim, bias=linear_bias\n",
" )\n",
"\n",
" def forward(self, embeddings, td):\n",
@@ -620,7 +620,7 @@
"# Instantiate policy with the embeddings we created above\n",
"emb_dim = 128\n",
"policy = AutoregressivePolicy(env,\n",
" embedding_dim=emb_dim,\n",
" embed_dim=emb_dim,\n",
" init_embedding=TSPInitEmbedding(emb_dim),\n",
" context_embedding=TSPContext(emb_dim),\n",
" dynamic_embedding=StaticEmbedding(emb_dim)\n",
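The `embedding_dim` → `embed_dim` rename above breaks any downstream caller still using the old keyword. As one possible migration aid — a hypothetical pure-Python shim, not part of rl4co — a constructor can accept both spellings for a deprecation period:

```python
import warnings


def normalize_embed_kwarg(kwargs: dict) -> dict:
    """Hypothetical helper: translate the legacy `embedding_dim` keyword to
    the new `embed_dim` spelling, warning the caller about the deprecation."""
    if "embedding_dim" in kwargs:
        if "embed_dim" in kwargs:
            raise TypeError("pass either 'embed_dim' or 'embedding_dim', not both")
        warnings.warn(
            "'embedding_dim' is deprecated; use 'embed_dim'", DeprecationWarning
        )
        kwargs = dict(kwargs)  # avoid mutating the caller's dict
        kwargs["embed_dim"] = kwargs.pop("embedding_dim")
    return kwargs


class TSPInitEmbeddingShim:
    """Illustrative stand-in for a module that tolerates both spellings."""

    def __init__(self, **kwargs):
        kwargs = normalize_embed_kwarg(kwargs)
        self.embed_dim = kwargs["embed_dim"]
```

Whether the PR itself offers such a shim is not shown here; the diff suggests a hard rename.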
24 changes: 12 additions & 12 deletions examples/advanced/2-flash-attention-2.ipynb
@@ -76,7 +76,7 @@
"from rl4co.envs import TSPEnv\n",
"from rl4co.models.zoo.am import AttentionModel\n",
"from rl4co.utils.trainer import RL4COTrainer\n",
"from rl4co.models.zoo.common.autoregressive import GraphAttentionEncoder\n",
"from rl4co.models.common.constructive.autoregressive import GraphAttentionEncoder\n",
"\n"
]
},
@@ -225,15 +225,15 @@
"env = TSPEnv(num_loc=1000)\n",
"\n",
"num_heads = 8\n",
"embedding_dim = 128\n",
"embed_dim = 128\n",
"num_layers = 3\n",
"enc_simple = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
"enc_simple = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention_simple)\n",
"\n",
"enc_fa1 = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
"enc_fa1 = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention)\n",
"\n",
"enc_fa2 = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
"enc_fa2 = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention_flash_attn)\n",
"\n",
"# Flash Attention supports only FP16 and BFloat16\n",
@@ -248,14 +248,14 @@
"metadata": {},
"outputs": [],
"source": [
"def build_models(num_heads=8, embedding_dim=128, num_layers=3):\n",
" enc_simple = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
"def build_models(num_heads=8, embed_dim=128, num_layers=3):\n",
" enc_simple = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention_simple)\n",
"\n",
" enc_fa1 = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
" enc_fa1 = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention)\n",
"\n",
" enc_fa2 = GraphAttentionEncoder(env, num_heads=num_heads, embedding_dim=embedding_dim, num_layers=num_layers,\n",
" enc_fa2 = GraphAttentionEncoder(env, num_heads=num_heads, embed_dim=embed_dim, num_layers=num_layers,\n",
" sdpa_fn=scaled_dot_product_attention_flash_attn)\n",
"\n",
" # Flash Attention supports only FP16 and BFloat16\n",
@@ -295,10 +295,10 @@
"times_fa1 = []\n",
"times_fa2 = []\n",
"\n",
"# for embedding_dim in [64, 128, 256]:\n",
"for embedding_dim in [128]:\n",
"# for embed_dim in [64, 128, 256]:\n",
"for embed_dim in [128]:\n",
" # Get models\n",
" enc_simple, enc_fa1, enc_fa2 = build_models(embedding_dim=embedding_dim)\n",
" enc_simple, enc_fa1, enc_fa2 = build_models(embed_dim=embed_dim)\n",
"\n",
" for problem_size in sizes:\n",
"\n",
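The notebook above times three encoder variants across problem sizes. The measurement pattern can be sketched independently of rl4co or torch — the helper and `fake_encoder` below are illustrative stand-ins, not the notebook's actual code:

```python
import time


def benchmark(fn, *args, warmup: int = 2, repeats: int = 5) -> float:
    """Return the best-of-`repeats` wall-clock time of fn(*args), in seconds.
    Warmup calls are excluded so one-time setup cost does not skew results."""
    for _ in range(warmup):
        fn(*args)
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best


def fake_encoder(n):
    # Stand-in workload for an encoder forward pass
    return sum(i * i for i in range(n))


# Usage mirroring the notebook's loop over problem sizes
times = {n: benchmark(fake_encoder, n) for n in [100, 1000]}
```

Taking the best of several repeats (rather than the mean) reduces noise from background load; for real GPU kernels the notebook's approach of comparing identical inputs across `sdpa_fn` implementations is what makes the timings comparable.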
2 changes: 1 addition & 1 deletion examples/modeling/1-decoding-strategies.ipynb
@@ -70,7 +70,7 @@
"\n",
"# Policy: neural network, in this case with encoder-decoder architecture\n",
"policy = AttentionModelPolicy(env.name, \n",
" embedding_dim=128,\n",
" embed_dim=128,\n",
" num_encoder_layers=3,\n",
" num_heads=8,\n",
" )\n",
2 changes: 1 addition & 1 deletion examples/modeling/2-search-methods.ipynb
@@ -190,7 +190,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:rl4co.models.zoo.common.autoregressive.policy:Instantiated environment not provided; instantiating tsp\n"
"INFO:rl4co.models.common.constructive.autoregressive.policy:Instantiated environment not provided; instantiating tsp\n"
]
},
{
8 changes: 4 additions & 4 deletions examples/modeling/3-change-encoder.ipynb
@@ -280,14 +280,14 @@
"\n",
"gcn_encoder = GCNEncoder(\n",
" env_name='cvrp', \n",
" embedding_dim=128,\n",
" embed_dim=128,\n",
" num_nodes=20, \n",
" num_layers=3,\n",
")\n",
"\n",
"mpnn_encoder = MessagePassingEncoder(\n",
" env_name='cvrp', \n",
" embedding_dim=128,\n",
" embed_dim=128,\n",
" num_nodes=20, \n",
" num_layers=3,\n",
")\n",
@@ -464,15 +464,15 @@
" def __init__(\n",
" self,\n",
" env_name: str,\n",
" embedding_dim: int,\n",
" embed_dim: int,\n",
" init_embedding: nn.Module = None,\n",
" ):\n",
" super(BaseEncoder, self).__init__()\n",
" self.env_name = env_name\n",
" \n",
" # Init embedding for each environment\n",
" self.init_embedding = (\n",
" env_init_embedding(self.env_name, {\"embedding_dim\": embedding_dim})\n",
" env_init_embedding(self.env_name, {\"embed_dim\": embed_dim})\n",
" if init_embedding is None\n",
" else init_embedding\n",
" )\n",
2 changes: 1 addition & 1 deletion rl4co/__init__.py
@@ -1 +1 @@
__version__ = "0.4.0dev1"
__version__ = "0.4.0dev2"
26 changes: 17 additions & 9 deletions rl4co/models/__init__.py
@@ -1,17 +1,25 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import (
from rl4co.models.common.constructive.autoregressive import (
AutoregressiveDecoder,
AutoregressiveEncoder,
AutoregressivePolicy,
GraphAttentionEncoder,
)
from rl4co.models.zoo.common.nonautoregressive import (
from rl4co.models.common.constructive.base import (
ConstructiveDecoder,
ConstructiveEncoder,
ConstructivePolicy,
)
from rl4co.models.common.constructive.nonautoregressive import (
NonAutoregressiveDecoder,
NonAutoregressiveEncoder,
NonAutoregressiveModel,
NonAutoregressivePolicy,
)
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.common.search import SearchBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.ppo.ppo import PPO
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.deepaco import DeepACO, DeepACOPolicy
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
@@ -20,7 +28,7 @@
)
from rl4co.models.zoo.matnet import MatNet, MatNetPolicy
from rl4co.models.zoo.mdam import MDAM, MDAMPolicy
from rl4co.models.zoo.pomo import POMO, POMOPolicy
from rl4co.models.zoo.ppo import PPOModel, PPOPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.models.zoo.ppo import PPOModel
from rl4co.models.zoo.ptrnet import PointerNetwork, PointerNetworkPolicy
from rl4co.models.zoo.symnco import SymNCO, SymNCOPolicy
15 changes: 15 additions & 0 deletions rl4co/models/common/__init__.py
@@ -0,0 +1,15 @@
from rl4co.models.common.constructive.autoregressive import (
AutoregressiveDecoder,
AutoregressiveEncoder,
AutoregressivePolicy,
)
from rl4co.models.common.constructive.base import (
ConstructiveDecoder,
ConstructiveEncoder,
ConstructivePolicy,
)
from rl4co.models.common.constructive.nonautoregressive import (
NonAutoregressiveDecoder,
NonAutoregressiveEncoder,
NonAutoregressivePolicy,
)
15 changes: 15 additions & 0 deletions rl4co/models/common/constructive/__init__.py
@@ -0,0 +1,15 @@
from rl4co.models.common.constructive.autoregressive import (
AutoregressiveDecoder,
AutoregressiveEncoder,
AutoregressivePolicy,
)
from rl4co.models.common.constructive.base import (
ConstructiveDecoder,
ConstructiveEncoder,
ConstructivePolicy,
)
from rl4co.models.common.constructive.nonautoregressive import (
NonAutoregressiveDecoder,
NonAutoregressiveEncoder,
NonAutoregressivePolicy,
)
3 changes: 3 additions & 0 deletions rl4co/models/common/constructive/autoregressive/__init__.py
@@ -0,0 +1,3 @@
from rl4co.models.common.constructive.autoregressive.decoder import AutoregressiveDecoder
from rl4co.models.common.constructive.autoregressive.encoder import AutoregressiveEncoder
from rl4co.models.common.constructive.autoregressive.policy import AutoregressivePolicy
11 changes: 11 additions & 0 deletions rl4co/models/common/constructive/autoregressive/decoder.py
@@ -0,0 +1,11 @@
from rl4co.models.common.constructive.base import ConstructiveDecoder


class AutoregressiveDecoder(ConstructiveDecoder):
    """Template class for an autoregressive decoder, a simple wrapper around
    :class:`~rl4co.models.common.constructive.base.ConstructiveDecoder`.

    Tip:
        This class will not work as-is and is just a template.
        An example of an autoregressive decoder is :class:`~rl4co.models.zoo.am.decoder.AttentionModelDecoder`.
    """
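The commit log notes the refactor moved these templates onto abstract bases via `abc.ABCMeta`. The pattern can be sketched in isolation — the class and method names below are illustrative stand-ins, not rl4co's exact API:

```python
import abc


class ConstructiveDecoderSketch(abc.ABC):
    """Illustrative abstract base: subclasses must implement `forward`."""

    @abc.abstractmethod
    def forward(self, td, embeddings):
        """Map encoder embeddings and environment state to action scores."""


class AttentionDecoderSketch(ConstructiveDecoderSketch):
    """Minimal concrete subclass, for demonstration only."""

    def forward(self, td, embeddings):
        # Trivial implementation: one zero score per embedding
        return [0.0 for _ in embeddings]
```

The benefit of the ABC pattern over a plain base class is that instantiating the template directly, or subclassing without implementing `forward`, fails immediately with a `TypeError` instead of at first use.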
11 changes: 11 additions & 0 deletions rl4co/models/common/constructive/autoregressive/encoder.py
@@ -0,0 +1,11 @@
from rl4co.models.common.constructive.base import ConstructiveEncoder


class AutoregressiveEncoder(ConstructiveEncoder):
    """Template class for an autoregressive encoder, a simple wrapper around
    :class:`~rl4co.models.common.constructive.base.ConstructiveEncoder`.

    Tip:
        This class will not work as-is and is just a template.
        An example of an autoregressive encoder is :class:`~rl4co.models.zoo.am.encoder.AttentionModelEncoder`.
    """
49 changes: 49 additions & 0 deletions rl4co/models/common/constructive/autoregressive/policy.py
@@ -0,0 +1,49 @@
from typing import Union

from rl4co.envs import RL4COEnvBase
from rl4co.models.common.constructive.base import ConstructivePolicy

from .decoder import AutoregressiveDecoder
from .encoder import AutoregressiveEncoder


class AutoregressivePolicy(ConstructivePolicy):
    """Template class for an autoregressive policy, a simple wrapper around
    :class:`~rl4co.models.common.constructive.base.ConstructivePolicy`.

    Note:
        While a decoder is required, an encoder is optional and will be initialized to
        :class:`~rl4co.models.common.constructive.autoregressive.encoder.NoEncoder`.
        This can be used in decoder-only models, in which actions at each step do not
        depend on previously encoded states.
    """

    def __init__(
        self,
        encoder: AutoregressiveEncoder = None,
        decoder: AutoregressiveDecoder = None,
        env_name: Union[str, RL4COEnvBase] = "tsp",
        temperature: float = 1.0,
        tanh_clipping: float = 0,
        mask_logits: bool = True,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kw,
    ):
        # Raise an error for the user if no decoder was provided
        if decoder is None:
            raise ValueError("AutoregressivePolicy requires a decoder to be provided.")

        super(AutoregressivePolicy, self).__init__(
            encoder=encoder,
            decoder=decoder,
            env_name=env_name,
            temperature=temperature,
            tanh_clipping=tanh_clipping,
            mask_logits=mask_logits,
            train_decode_type=train_decode_type,
            val_decode_type=val_decode_type,
            test_decode_type=test_decode_type,
            **unused_kw,
        )