Merge pull request #169 from ai4co/refactor-env-base
Major environment refactoring (base version)
fedebotu authored May 1, 2024
2 parents 97e90b9 + 862731b commit f7c984c
Showing 57 changed files with 4,472 additions and 3,080 deletions.
27 changes: 21 additions & 6 deletions rl4co/envs/common/base.py
@@ -129,6 +129,16 @@ def step(self, td: TensorDict) -> TensorDict:
        # Since we simplify the syntax
        return self._torchrl_step(td)

    def reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
        """Reset function to call at the beginning of each episode"""
        if batch_size is None:
            batch_size = self.batch_size if td is None else td.batch_size
        if td is None or td.is_empty():
            td = self.generator(batch_size=batch_size)
        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
        self.to(td.device)
        return super().reset(td, batch_size=batch_size)
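As a usage note (not part of the diff): with this change, reset either consumes a caller-provided TensorDict or asks the environment's generator for a fresh batch. A minimal sketch, assuming a TSP environment and a generator_params constructor argument (both assumptions, not shown in this commit):

    from rl4co.envs import TSPEnv  # assumed import path

    env = TSPEnv(generator_params={"num_loc": 50})  # assumed constructor argument

    # Fresh instances produced by env.generator
    td = env.reset(batch_size=[64])

    # Or reset from pre-generated instances
    td = env.reset(env.generator(batch_size=64))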

    def _torchrl_step(self, td: TensorDict) -> TensorDict:
        """See :meth:`super().step` for more details.
        This is the usual way to do it in TorchRL, but inefficient in our case

@@ -167,6 +177,15 @@ def _make_spec(self, td_params: TensorDict = None):
        raise NotImplementedError

    def get_reward(self, td, actions) -> TensorDict:
        """Function to compute the reward. Can be called by the agent to compute the reward of the current state
        This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks
        """
        if self.check_solution:
            self.check_solution_validity(td, actions)
        return self._get_reward(td, actions)

    @abc.abstractmethod
    def _get_reward(self, td, actions) -> TensorDict:
        """Function to compute the reward. Can be called by the agent to compute the reward of the current state
        This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks
@@ -200,7 +219,7 @@ def dataset(self, batch_size=[], phase="train", filename=None):
        if f is None:
            if phase != "train":
                log.warning(f"{phase}_file not set. Generating dataset instead")
-            td = self.generate_data(batch_size)
+            td = self.generator(batch_size)
        else:
            log.info(f"Loading {phase} dataset from {f}")
            if phase == "train":
@@ -222,14 +241,10 @@ def dataset(self, batch_size=[], phase="train", filename=None):
                    f"Provided file name {f} not found. Make sure to provide a file in the right path first or "
                    f"unset {phase}_file to generate data automatically instead"
                )
-                td = self.generate_data(batch_size)
+                td = self.generator(batch_size)

        return self.dataset_cls(td)

-    def generate_data(self, batch_size):
-        """Dataset generation"""
-        raise NotImplementedError
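In practice (a hedged sketch reusing the env from the reset example above; the file path is hypothetical), dataset() now falls back to the generator whenever no {phase}_file is configured:

    # Generate 100,000 training instances on the fly via env.generator
    train_dataset = env.dataset(batch_size=[100_000], phase="train")

    # Or load a pre-generated validation set from disk (hypothetical path)
    # val_dataset = env.dataset(phase="val", filename="data/tsp/tsp50_val.npz")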

    def transform(self):
        """Used for converting TensorDict variables (such as with torch.cat) efficiently
        https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html
58 changes: 56 additions & 2 deletions rl4co/envs/common/utils.py
@@ -1,8 +1,9 @@
-from typing import Optional
+from typing import Optional, Callable, Union

import torch

-from tensordict.tensordict import TensorDictBase
+from torch.distributions import Uniform, Normal, Exponential, Poisson
+from tensordict.tensordict import TensorDictBase, TensorDict
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec


@@ -44,3 +45,56 @@ def _getstate_env(self):
    state = self.__dict__.copy()
    del state["rng"]
    return state


class Generator():
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, batch_size) -> TensorDict:
        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
        return self._generate(batch_size)

    def _generate(self, batch_size, **kwargs) -> TensorDict:
        raise NotImplementedError
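To show the intended extension point (a sketch under assumptions: the subclass name, num_loc, and the "locs" key are illustrative, not taken from this commit), a concrete generator subclasses Generator and implements _generate:

    import torch
    from tensordict.tensordict import TensorDict

    from rl4co.envs.common.utils import Generator  # the base class added above


    class RandomLocationGenerator(Generator):
        """Toy generator: num_loc uniform 2D coordinates per instance."""

        def __init__(self, num_loc: int = 20, **kwargs):
            super().__init__(**kwargs)
            self.num_loc = num_loc

        def _generate(self, batch_size, **kwargs) -> TensorDict:
            locs = torch.rand(*batch_size, self.num_loc, 2)
            return TensorDict({"locs": locs}, batch_size=batch_size)


    # gen = RandomLocationGenerator(num_loc=20)
    # td = gen(64)  # __call__ wraps the int into [64] before dispatching to _generate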


def get_sampler(
    val_name: str,
    distribution: Union[int, float, str, type, Callable],
    min_val: float,
    max_val: float,
    **kwargs,
):
    """Get the sampler for the variable with the given distribution.

    Args:
        val_name: Name of the variable
        distribution: int/float value (as constant distribution), or string with the distribution name (supporting
            uniform, normal, exponential, and poisson) or PyTorch Distribution type or a callable function that
            returns a PyTorch Distribution
        min_val: Minimum value for the variable, used for the Uniform distribution
        max_val: Maximum value for the variable, used for the Uniform distribution
        kwargs: Additional arguments for the distribution
    """
    if isinstance(distribution, (int, float)):
        return Uniform(low=distribution, high=distribution)
    elif distribution == "center":  # Depot
        return Uniform(low=(max_val - min_val) / 2, high=(max_val - min_val) / 2)
    elif distribution == "corner":  # Depot
        return Uniform(low=min_val, high=min_val)
    elif distribution == Uniform or distribution == "uniform":
        return Uniform(low=min_val, high=max_val)
    elif distribution == Normal or distribution == "normal":
        assert kwargs.get(val_name + "_mean", None) is not None, "mean is required for Normal distribution"
        assert kwargs.get(val_name + "_std", None) is not None, "std is required for Normal distribution"
        return Normal(loc=kwargs[val_name + "_mean"], scale=kwargs[val_name + "_std"])
    elif distribution == Exponential or distribution == "exponential":
        assert kwargs.get(val_name + "_rate", None) is not None, "rate is required for Exponential distribution"
        return Exponential(rate=kwargs[val_name + "_rate"])
    elif distribution == Poisson or distribution == "poisson":
        assert kwargs.get(val_name + "_rate", None) is not None, "rate is required for Poisson distribution"
        return Poisson(rate=kwargs[val_name + "_rate"])
    elif isinstance(distribution, Callable):
        return distribution(**kwargs)
    else:
        raise ValueError(f"Invalid distribution type of {distribution}")
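A brief usage sketch (illustrative; the variable name "loc" and the shapes are assumptions): get_sampler returns a plain torch.distributions object, so sampling follows the usual Distribution API.

    from rl4co.envs.common.utils import get_sampler

    # Uniform sampler over [0, 1) for 2D customer locations
    loc_sampler = get_sampler("loc", "uniform", 0.0, 1.0)
    locs = loc_sampler.sample((64, 20, 2))  # (batch, num_loc, xy)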
4 changes: 2 additions & 2 deletions rl4co/envs/eda/__init__.py
@@ -1,2 +1,2 @@
-from rl4co.envs.eda.dpp import DPPEnv
-from rl4co.envs.eda.mdpp import MDPPEnv
+from rl4co.envs.eda.dpp.env import DPPEnv
+from rl4co.envs.eda.mdpp.env import MDPPEnv
