diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index 432a338c..12b2b347 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -129,6 +129,16 @@ def step(self, td: TensorDict) -> TensorDict: # Since we simplify the syntax return self._torchrl_step(td) + def reset(self, td: Optional[TensorDict] = None, batch_size = None) -> TensorDict: + """Reset function to call at the beginning of each episode""" + if batch_size is None: + batch_size = self.batch_size if td is None else td.batch_size + if td is None or td.is_empty(): + td = self.generator(batch_size=batch_size) + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + self.to(td.device) + return super().reset(td, batch_size=batch_size) + def _torchrl_step(self, td: TensorDict) -> TensorDict: """See :meth:`super().step` for more details. This is the usual way to do it in TorchRL, but inefficient in our case @@ -167,6 +177,15 @@ def _make_spec(self, td_params: TensorDict = None): raise NotImplementedError def get_reward(self, td, actions) -> TensorDict: + """Function to compute the reward. Can be called by the agent to compute the reward of the current state + This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks + """ + if self.check_solution: + self.check_solution_validity(td, actions) + return self._get_reward(td, actions) + + @abc.abstractmethod + def _get_reward(self, td, actions) -> TensorDict: """Function to compute the reward. Can be called by the agent to compute the reward of the current state This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks """ @@ -200,7 +219,7 @@ def dataset(self, batch_size=[], phase="train", filename=None): if f is None: if phase != "train": log.warning(f"{phase}_file not set. Generating dataset instead") - td = self.generate_data(batch_size) + td = self.generator(batch_size) else: log.info(f"Loading {phase} dataset from {f}") if phase == "train": @@ -222,14 +241,10 @@ def dataset(self, batch_size=[], phase="train", filename=None): f"Provided file name {f} not found. Make sure to provide a file in the right path first or " f"unset {phase}_file to generate data automatically instead" ) - td = self.generate_data(batch_size) + td = self.generator(batch_size) return self.dataset_cls(td) - def generate_data(self, batch_size): - """Dataset generation""" - raise NotImplementedError - def transform(self): """Used for converting TensorDict variables (such as with torch.cat) efficiently https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html diff --git a/rl4co/envs/common/utils.py b/rl4co/envs/common/utils.py index 41cecb55..7ee66602 100644 --- a/rl4co/envs/common/utils.py +++ b/rl4co/envs/common/utils.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Optional, Callable, Union import torch -from tensordict.tensordict import TensorDictBase +from torch.distributions import Uniform, Normal, Exponential, Poisson +from tensordict.tensordict import TensorDictBase, TensorDict from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec @@ -44,3 +45,56 @@ def _getstate_env(self): state = self.__dict__.copy() del state["rng"] return state + + +class Generator(): + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, batch_size) -> TensorDict: + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + return self._generate(batch_size) + + def _generate(self, batch_size, **kwargs) -> TensorDict: + raise NotImplementedError + + +def get_sampler( + val_name: str, + distribution: Union[int, float, str, type, Callable], + min_val: float, + max_val: float, + **kwargs + ): + """Get the sampler for the variable with the given distribution + Args: + val_name: Name of the variable + distribution: int/float value (as constant distribution), or string with the distribution name (supporting + uniform, normal, exponential, and poisson) or PyTorch Distribution type or a callable function that + returns a PyTorch Distribution + min_val: Minimum value for the variable, used for Uniform distribution + max_val: Maximum value for the variable, used for Uniform distribution + kwargs: Additional arguments for the distribution + """ + if isinstance(distribution, (int, float)): + return Uniform(low=distribution, high=distribution) + elif distribution == "center": # Depot + return Uniform(low=(max_val-min_val)/2, high=(max_val-min_val)/2) + elif distribution == "corner": # Depot + return Uniform(low=min_val, high=min_val) + elif distribution == Uniform or distribution == "uniform": + return Uniform(low=min_val, high=max_val) + elif distribution == Normal or distribution == "normal": + assert kwargs.get("mean_"+val_name, None) is not None, "mean is required for Normal distribution" + assert kwargs.get(val_name+"_std", None) is not None, "std is required for Normal distribution" + return Normal(mean=kwargs[val_name+"_mean"], std=kwargs[val_name+"_std"]) + elif distribution == Exponential or distribution == "exponential": + assert kwargs.get(val_name+"_rate", None) is not None, "rate is required for Exponential/Poisson distribution" + return Exponential(rate=kwargs[val_name+"_rate"]) + elif distribution == Poisson or distribution == "poisson": + assert kwargs.get(val_name+"_rate", None) is not None, "rate is required for Exponential/Poisson distribution" + return Poisson(rate=kwargs[val_name+"_rate"]) + elif isinstance(distribution, Callable): + return distribution(**kwargs) + else: + raise ValueError(f"Invalid distribution type of {distribution}") diff --git a/rl4co/envs/eda/__init__.py b/rl4co/envs/eda/__init__.py index da7f45e2..ce8f2f8c 100644 --- a/rl4co/envs/eda/__init__.py +++ b/rl4co/envs/eda/__init__.py @@ -1,2 +1,2 @@ -from rl4co.envs.eda.dpp import DPPEnv -from rl4co.envs.eda.mdpp import MDPPEnv +from rl4co.envs.eda.dpp.env import DPPEnv +from rl4co.envs.eda.mdpp.env import MDPPEnv diff --git a/rl4co/envs/eda/dpp.py b/rl4co/envs/eda/dpp.py deleted file mode 100644 index 05410af3..00000000 --- a/rl4co/envs/eda/dpp.py +++ /dev/null @@ -1,422 +0,0 @@ -import os -import zipfile - -from typing import Optional - -import numpy as np -import torch - -from robust_downloader import download -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.data.utils import load_npz_to_tensordict -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -class DPPEnv(RL4COEnvBase): - """Decap placement problem as done in DevFormer paper: https://arxiv.org/abs/2205.13225 - - The environment is a 10x10 grid with 100 locations containing either a probing port or a keepout region. - The goal is to place decaps (decoupling capacitors) to maximize the impedance suppression at the probing port. - Decaps cannot be placed in keepout regions or at the probing port and the number of decaps is limited. - - Args: - min_loc: Minimum location value. Defaults to 0. - max_loc: Maximum location value. Defaults to 1. - num_keepout_min: Minimum number of keepout regions. Defaults to 1. - num_keepout_max: Maximum number of keepout regions. Defaults to 50. - max_decaps: Maximum number of decaps. Defaults to 20. - data_dir: Directory to store data. Defaults to "data/dpp/". - This can be downloaded from this [url](https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95). - chip_file: Name of the chip file. Defaults to "10x10_pkg_chip.npy". - decap_file: Name of the decap file. Defaults to "01nF_decap.npy". - freq_file: Name of the frequency file. Defaults to "freq_201.npy". - url: URL to download data from. Defaults to None. - td_params: TensorDict parameters. Defaults to None. - """ - - name = "dpp" - - def __init__( - self, - *, - min_loc: float = 0, - max_loc: float = 1, - num_keepout_min: int = 1, - num_keepout_max: int = 50, - max_decaps: int = 20, - data_dir: str = "data/dpp/", - chip_file: str = "10x10_pkg_chip.npy", - decap_file: str = "01nF_decap.npy", - freq_file: str = "freq_201.npy", - url: str = None, - td_params: TensorDict = None, - **kwargs, - ): - kwargs["data_dir"] = data_dir - super().__init__(**kwargs) - - self.url = ( - "https://github.com/kaist-silab/devformer/raw/main/data/data.zip" - if url is None - else url - ) - self.backup_url = ( - "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" - ) - self._load_dpp_data(chip_file, decap_file, freq_file) - self.min_loc = min_loc - self.max_loc = max_loc - self.num_keepout_min = num_keepout_min - self.num_keepout_max = num_keepout_max - self.max_decaps = max_decaps - - assert ( - num_keepout_min <= num_keepout_max - ), "num_keepout_min must be <= num_keepout_max" - assert ( - num_keepout_max <= self.size**2 - ), "num_keepout_max must be <= size * size (total number of locations)" - - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - current_node = td["action"] - - # Set available to 0 (i.e., already placed) if the current node is the first node - available = td["action_mask"].scatter( - -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0 - ) - - # Set done if i is greater than max_decaps - done = td["i"] >= self.max_decaps - 1 - - # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here - reward = torch.zeros_like(done) - - td.update( - { - "i": td["i"] + 1, - "action_mask": available, - "reward": reward, - "done": done, - } - ) - return td - - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - # Initialize locations - if batch_size is None: - batch_size = self.batch_size if td is None else td.batch_size - device = td.device if td is not None else self.device - self.to(device) - - # We allow loading the initial observation from a dataset for faster loading - if td is None: - td = self.generate_data(batch_size=batch_size) - - # Other variables - i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) - - return TensorDict( - { - "locs": td["locs"], - "probe": td["probe"], - "i": i, - "action_mask": td["action_mask"], - "keepout": ~td["action_mask"], - }, - batch_size=batch_size, - ) - - def _make_spec(self, td_params): - """Make the observation and action specs from the parameters""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.size**2, 2), - dtype=torch.float32, - ), - probe=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - keepout=UnboundedDiscreteTensorSpec( - shape=(self.size**2), - dtype=torch.bool, - ), - i=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.size**2), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.size**2, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - def get_reward(self, td, actions): - """ - We call the reward function with the final sequence of actions to get the reward - Calling per-step would be very time consuming due to decap simulation - """ - # We do the operation in a batch - if len(td.batch_size) == 0: - td = td.unsqueeze(0) - actions = actions.unsqueeze(0) - probes = td["probe"] - reward = torch.stack( - [self._decap_simulator(p, a) for p, a in zip(probes, actions)] - ) - return reward - - def generate_data(self, batch_size): - """ - Generate initial observations for the environment with locations, probe, and action mask - Action_mask eliminates the keepout regions and the probe location, and is updated to eliminate placed decaps - """ - m = n = self.size - # if int, convert to list and make it a batch for easier generation - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - batched = len(batch_size) > 0 - bs = [1] if not batched else batch_size - - # Create a list of locs on a grid - locs = torch.meshgrid( - torch.arange(m, device=self.device), torch.arange(n, device=self.device) - ) - locs = torch.stack(locs, dim=-1).reshape(-1, 2) - # normalize the locations by the number of rows and columns - locs = locs / torch.tensor([m, n], dtype=torch.float, device=self.device) - locs = locs[None].expand(*bs, -1, -1) - - # Create available mask - available = torch.ones((*bs, m * n), dtype=torch.bool) - - # Sample probe location from m*n - probe = torch.randint(m * n, size=(*bs, 1)) - available.scatter_(1, probe, False) - - # Sample keepout locations from m*n except probe - num_keepout = torch.randint( - self.num_keepout_min, - self.num_keepout_max, - size=(*bs, 1), - device=self.device, - ) - keepouts = [torch.randperm(m * n)[:k] for k in num_keepout] - for i, (a, k) in enumerate(zip(available, keepouts)): - available[i] = a.scatter(0, k, False) - - return TensorDict( - { - "locs": locs if batched else locs.squeeze(0), - "probe": probe if batched else probe.squeeze(0), - "action_mask": available if batched else available.squeeze(0), - }, - batch_size=batch_size, - ) - - def _decap_placement(self, pi, probe): - device = pi.device - - n = m = self.size # columns and rows - num_decap = torch.numel(pi) - z1 = self.raw_pdn.to(device) - - decap = self.decap.reshape(-1).to(device) - z2 = torch.zeros( - (self.num_freq, num_decap, num_decap), dtype=torch.float32, device=device - ) - - qIndx = torch.arange(num_decap, device=device) - - z2[:, qIndx, qIndx] = torch.abs(decap)[:, None].repeat_interleave( - z2[:, qIndx, qIndx].shape[-1], dim=-1 - ) - pIndx = pi.long() - - aIndx = torch.arange(len(z1[0]), device=device) - aIndx = torch.tensor( - list(set(aIndx.tolist()) - set(pIndx.tolist())), device=device - ) - - z1aa = z1[:, aIndx, :][:, :, aIndx] - z1ap = z1[:, aIndx, :][:, :, pIndx] - z1pa = z1[:, pIndx, :][:, :, aIndx] - z1pp = z1[:, pIndx, :][:, :, pIndx] - z2qq = z2[:, qIndx, :][:, :, qIndx] - - zout = z1aa - torch.matmul(torch.matmul(z1ap, torch.inverse(z1pp + z2qq)), z1pa) - - idx = torch.arange(n * m, device=device) - mask = torch.zeros(n * m, device=device).bool() - mask[pi] = True - mask = mask & (idx < probe) - probe -= mask.sum().item() - - zout = zout[:, probe, probe] - return zout - - def _decap_model(self, z_initial, z_final): - impedance_gap = torch.zeros(self.num_freq, device=self.device) - - impedance_gap = z_initial - z_final - reward = torch.sum(impedance_gap * 1000000000 / self.freq.to(self.device)) - - reward = reward / 10 - return reward - - def _initial_impedance(self, probe): - zout = self.raw_pdn.to(self.device)[:, probe, probe] - return zout - - def _decap_simulator(self, probe, solution, keepout=None): - self.to(self.device) - - probe = probe.item() - - assert len(solution) == len( - torch.unique(solution) - ), "An Element of Decap Sequence must be Unique" - - if keepout is not None: - keepout = torch.tensor(keepout) - intersect = torch.tensor(list(set(solution.tolist()) & set(keepout.tolist()))) - assert len(intersect) == 0, "Decap must be not placed at the keepout region" - - z_initial = self._initial_impedance(probe) - z_initial = torch.abs(z_initial) - z_final = self._decap_placement(solution, probe) - z_final = torch.abs(z_final) - reward = self._decap_model(z_initial, z_final) - return reward - - def _load_dpp_data(self, chip_file, decap_file, freq_file): - def _load_file(fpath): - f = os.path.join(self.data_dir, fpath) - if not os.path.isfile(f): - self._download_data() - with open(f, "rb") as f_: - return torch.from_numpy(np.load(f_)).to(self.device) - - self.raw_pdn = _load_file(chip_file) # [num_freq, size^2, size^2] - self.decap = _load_file(decap_file).to(torch.complex64) # [num_freq, 1, 1] - self.freq = _load_file(freq_file) # [num_freq] - self.size = int(np.sqrt(self.raw_pdn.shape[-1])) - self.num_freq = self.freq.shape[0] - - def _download_data(self): - log.info("Downloading data...") - try: - download(self.url, self.data_dir, "data.zip") - except Exception: - log.error( - f"Download from main url {self.url} failed. Trying backup url {self.backup_url}..." - ) - download(self.backup_url, self.data_dir, "data.zip") - log.info("Download complete. Unzipping...") - zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall( - self.data_dir - ) - log.info("Unzip complete. Removing zip file") - os.remove(os.path.join(self.data_dir, "data.zip")) - - def load_data(self, fpath, batch_size=[]): - data = load_npz_to_tensordict(fpath) - # rename key if necessary (old dpp version) - if "observation" in data.keys(): - data["locs"] = data.pop("observation") - return data - - def render(self, decaps, probe, action_mask, ax=None, legend=True): - """ - Plot a grid of 1x1 squares representing the environment. - The keepout regions are the action_mask - decaps - probe - """ - import matplotlib.pyplot as plt - - settings = { - 0: {"color": "white", "label": "available"}, - 1: {"color": "grey", "label": "keepout"}, - 2: {"color": "tab:red", "label": "probe"}, - 3: {"color": "tab:blue", "label": "decap"}, - } - - nonzero_indices = torch.nonzero(~action_mask, as_tuple=True)[0] - keepout = torch.cat([nonzero_indices, probe, decaps.squeeze(-1)]) - unique_elements, counts = torch.unique(keepout, return_counts=True) - keepout = unique_elements[counts == 1] - - if ax is None: - fig, ax = plt.subplots(1, 1, figsize=(6, 6)) - - grid = np.meshgrid(np.arange(0, self.size), np.arange(0, self.size)) - grid = np.stack(grid, axis=-1) - - # Add new dimension to grid filled up with 0s - grid = np.concatenate([grid, np.zeros((self.size, self.size, 1))], axis=-1) - - # Add keepout = 1 - grid[keepout // self.size, keepout % self.size, 2] = 1 - # Add probe = 2 - grid[probe // self.size, probe % self.size, 2] = 2 - # Add decaps = 3 - grid[decaps // self.size, decaps % self.size, 2] = 3 - - xdim, ydim = grid.shape[0], grid.shape[1] - ax.imshow(np.zeros((xdim, ydim)), cmap="gray") - - ax.set_xlim(0, xdim) - ax.set_ylim(0, ydim) - - for i in range(xdim): - for j in range(ydim): - color = settings[grid[i, j, 2]]["color"] - x, y = grid[i, j, 0], grid[i, j, 1] - ax.add_patch(plt.Rectangle((x, y), 1, 1, color=color, linestyle="-")) - - # Add grid with 1x1 squares - ax.grid( - which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5 - ) - # set 10 ticks - ax.set_xticks(np.arange(0, xdim, 1)) - ax.set_yticks(np.arange(0, ydim, 1)) - - # Invert y axis - ax.invert_yaxis() - - # Add legend - if legend: - num_unique = 4 - handles = [ - plt.Rectangle((0, 0), 1, 1, color=settings[i]["color"]) - for i in range(num_unique) - ] - ax.legend( - handles, - [settings[i]["label"] for i in range(num_unique)], - ncol=num_unique, - loc="upper center", - bbox_to_anchor=(0.5, 1.1), - ) diff --git a/rl4co/envs/eda/dpp/env.py b/rl4co/envs/eda/dpp/env.py new file mode 100644 index 00000000..c8d89430 --- /dev/null +++ b/rl4co/envs/eda/dpp/env.py @@ -0,0 +1,260 @@ +import os +import zipfile + +from typing import Optional + +import numpy as np +import torch + +from robust_downloader import download +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.utils.pylogger import get_pylogger + +from .generator import DPPGenerator +from .render import render + +log = get_pylogger(__name__) + + +class DPPEnv(RL4COEnvBase): + """Decap Placement Problem (DPP) as done in DevFormer paper: https://arxiv.org/abs/2205.13225 + + The environment is a 10x10 grid with 100 locations containing either a probing port or a keepout region. + The goal is to place decaps (decoupling capacitors) to maximize the impedance suppression at the probing port. + Decaps cannot be placed in keepout regions or at the probing port and the number of decaps is limited. + + Observations: + - locations of the probing port and keepout regions + - current decap placement + - remaining decaps + + Constraints: + - decaps cannot be placed at the probing port or keepout regions + - the number of decaps is limited + + Finish Condition: + - the number of decaps exceeds the limit + + Reward: + - the impedance suppression at the probing port + + Args: + generator: DPPGenerator instance as the data generator + generator_params: parameters for the generator + """ + + name = "dpp" + + def __init__( + self, + generator: DPPGenerator = None, + generator_params: dict = {}, + **kwargs, + ): + super().__init__(**kwargs) + if generator is None: + generator = DPPGenerator(**generator_params) + self.generator = generator + + self.max_decaps = self.generator.max_decaps + self.size = self.generator.size + self.raw_pdn = self.generator.raw_pdn + self.decap = self.generator.decap + self.freq = self.generator.freq + self.num_freq = self.generator.num_freq + self.data_dir = self.generator.data_dir + + self._make_spec(self.generator) + + def _step(self, td: TensorDict) -> TensorDict: + current_node = td["action"] + + # Set available to 0 (i.e., already placed) if the current node is the first node + available = td["action_mask"].scatter( + -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0 + ) + + # Set done if i is greater than max_decaps + done = td["i"] >= self.max_decaps - 1 + + # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here + reward = torch.zeros_like(done) + + td.update( + { + "i": td["i"] + 1, + "action_mask": available, + "reward": reward, + "done": done, + } + ) + return td + + def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + device = td.device + + # Other variables + i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) + + return TensorDict( + { + "locs": td["locs"], + "probe": td["probe"], + "i": i, + "action_mask": td["action_mask"], + "keepout": ~td["action_mask"], + }, + batch_size=batch_size, + ) + + def _make_spec(self, generator: DPPGenerator): + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.size**2, 2), + dtype=torch.float32, + ), + probe=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + keepout=UnboundedDiscreteTensorSpec( + shape=(generator.size**2), + dtype=torch.bool, + ), + i=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(generator.size**2), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=generator.size**2, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + def _get_reward(self, td, actions): + """ + We call the reward function with the final sequence of actions to get the reward + Calling per-step would be very time consuming due to decap simulation + """ + # We do the operation in a batch + if len(td.batch_size) == 0: + td = td.unsqueeze(0) + actions = actions.unsqueeze(0) + probes = td["probe"] + reward = torch.stack( + [self._decap_simulator(p, a) for p, a in zip(probes, actions)] + ) + return reward + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + assert True, "Not implemented" + + def _decap_placement(self, pi, probe): + device = pi.device + + n = m = self.size # columns and rows + num_decap = torch.numel(pi) + z1 = self.raw_pdn.to(device) + + decap = self.decap.reshape(-1).to(device) + z2 = torch.zeros( + (self.num_freq, num_decap, num_decap), dtype=torch.float32, device=device + ) + + qIndx = torch.arange(num_decap, device=device) + + z2[:, qIndx, qIndx] = torch.abs(decap)[:, None].repeat_interleave( + z2[:, qIndx, qIndx].shape[-1], dim=-1 + ) + pIndx = pi.long() + + aIndx = torch.arange(len(z1[0]), device=device) + aIndx = torch.tensor( + list(set(aIndx.tolist()) - set(pIndx.tolist())), device=device + ) + + z1aa = z1[:, aIndx, :][:, :, aIndx] + z1ap = z1[:, aIndx, :][:, :, pIndx] + z1pa = z1[:, pIndx, :][:, :, aIndx] + z1pp = z1[:, pIndx, :][:, :, pIndx] + z2qq = z2[:, qIndx, :][:, :, qIndx] + + zout = z1aa - torch.matmul(torch.matmul(z1ap, torch.inverse(z1pp + z2qq)), z1pa) + + idx = torch.arange(n * m, device=device) + mask = torch.zeros(n * m, device=device).bool() + mask[pi] = True + mask = mask & (idx < probe) + probe -= mask.sum().item() + + zout = zout[:, probe, probe] + return zout + + def _decap_model(self, z_initial, z_final): + impedance_gap = torch.zeros(self.num_freq, device=self.device) + + impedance_gap = z_initial - z_final + reward = torch.sum(impedance_gap * 1000000000 / self.freq.to(self.device)) + + reward = reward / 10 + return reward + + def _initial_impedance(self, probe): + zout = self.raw_pdn.to(self.device)[:, probe, probe] + return zout + + def _decap_simulator(self, probe, solution, keepout=None): + self.to(self.device) + + probe = probe.item() + + assert len(solution) == len( + torch.unique(solution) + ), "An Element of Decap Sequence must be Unique" + + if keepout is not None: + keepout = torch.tensor(keepout) + intersect = torch.tensor(list(set(solution.tolist()) & set(keepout.tolist()))) + assert len(intersect) == 0, "Decap must be not placed at the keepout region" + + z_initial = self._initial_impedance(probe) + z_initial = torch.abs(z_initial) + z_final = self._decap_placement(solution, probe) + z_final = torch.abs(z_final) + reward = self._decap_model(z_initial, z_final) + return reward + + def _load_dpp_data(self, chip_file, decap_file, freq_file): + def _load_file(fpath): + f = os.path.join(self.generator.data_dir, fpath) + if not os.path.isfile(f): + self._download_data() + with open(f, "rb") as f_: + return torch.from_numpy(np.load(f_)).to(self.device) + + self.raw_pdn = _load_file(chip_file) # [num_freq, size^2, size^2] + self.decap = _load_file(decap_file).to(torch.complex64) # [num_freq, 1, 1] + self.freq = _load_file(freq_file) # [num_freq] + self.size = int(np.sqrt(self.raw_pdn.shape[-1])) + self.num_freq = self.freq.shape[0] diff --git a/rl4co/envs/eda/dpp/generator.py b/rl4co/envs/eda/dpp/generator.py new file mode 100644 index 00000000..d34b8e7c --- /dev/null +++ b/rl4co/envs/eda/dpp/generator.py @@ -0,0 +1,169 @@ +import os +import zipfile +from typing import Union, Callable + +import torch +import numpy as np + +from robust_downloader import download +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + + +class DPPGenerator(Generator): + """Data generator for the Decap Placement Problem (DPP). + + Args: + min_loc: Minimum location value. Defaults to 0. + max_loc: Maximum location value. Defaults to 1. + num_keepout_min: Minimum number of keepout regions. Defaults to 1. + num_keepout_max: Maximum number of keepout regions. Defaults to 50. + max_decaps: Maximum number of decaps. Defaults to 20. + data_dir: Directory to store data. Defaults to "data/dpp/". + This can be downloaded from this [url](https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95). + chip_file: Name of the chip file. Defaults to "10x10_pkg_chip.npy". + decap_file: Name of the decap file. Defaults to "01nF_decap.npy". + freq_file: Name of the frequency file. Defaults to "freq_201.npy". + url: URL to download data from. Defaults to None. + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + demand [batch_size, num_loc]: demand of each customer + capacity [batch_size]: capacity of the vehicle + """ + def __init__( + self, + min_loc: float = 0.0, + max_loc: float = 1.0, + num_keepout_min: int = 1, + num_keepout_max: int = 50, + max_decaps: int = 20, + data_dir: str = "data/dpp/", + chip_file: str = "10x10_pkg_chip.npy", + decap_file: str = "01nF_decap.npy", + freq_file: str = "freq_201.npy", + url: str = None, + **unused_kwargs + ): + self.min_loc = min_loc + self.max_loc = max_loc + self.num_keepout_min = num_keepout_min + self.num_keepout_max = num_keepout_max + self.max_decaps = max_decaps + self.data_dir = data_dir + + # DPP environment doen't have any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + + # Download and load the data from online dataset + self.url = ( + "https://github.com/kaist-silab/devformer/raw/main/data/data.zip" + if url is None + else url + ) + self.backup_url = ( + "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" + ) + self._load_dpp_data(chip_file, decap_file, freq_file) + + # Check the validity of the keepout parameters + assert ( + num_keepout_min <= num_keepout_max + ), "num_keepout_min must be <= num_keepout_max" + assert ( + num_keepout_max <= self.size**2 + ), "num_keepout_max must be <= size * size (total number of locations)" + + def _generate(self, batch_size) -> TensorDict: + """ + Generate initial observations for the environment with locations, probe, and action mask + Action_mask eliminates the keepout regions and the probe location, and is updated to eliminate placed decaps + """ + m = n = self.size + # if int, convert to list and make it a batch for easier generation + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + batched = len(batch_size) > 0 + bs = [1] if not batched else batch_size + + # Create a list of locs on a grid + locs = torch.meshgrid( + torch.arange(m), torch.arange(n) + ) + locs = torch.stack(locs, dim=-1).reshape(-1, 2) + # normalize the locations by the number of rows and columns + locs = locs / torch.tensor([m, n], dtype=torch.float) + locs = locs[None].expand(*bs, -1, -1) + + # Create available mask + available = torch.ones((*bs, m * n), dtype=torch.bool) + + # Sample probe location from m*n + probe = torch.randint(m * n, size=(*bs, 1)) + available.scatter_(1, probe, False) + + # Sample keepout locations from m*n except probe + num_keepout = torch.randint( + self.num_keepout_min, + self.num_keepout_max, + size=(*bs, 1), + ) + keepouts = [torch.randperm(m * n)[:k] for k in num_keepout] + for i, (a, k) in enumerate(zip(available, keepouts)): + available[i] = a.scatter(0, k, False) + + return TensorDict( + { + "locs": locs if batched else locs.squeeze(0), + "probe": probe if batched else probe.squeeze(0), + "action_mask": available if batched else available.squeeze(0), + }, + batch_size=batch_size, + ) + + def _load_dpp_data(self, chip_file, decap_file, freq_file): + def _load_file(fpath): + f = os.path.join(self.data_dir, fpath) + if not os.path.isfile(f): + self._download_data() + with open(f, "rb") as f_: + return torch.from_numpy(np.load(f_)) + + self.raw_pdn = _load_file(chip_file) # [num_freq, size^2, size^2] + self.decap = _load_file(decap_file).to(torch.complex64) # [num_freq, 1, 1] + self.freq = _load_file(freq_file) # [num_freq] + self.size = int(np.sqrt(self.raw_pdn.shape[-1])) + self.num_freq = self.freq.shape[0] + + def _download_data(self): + log.info("Downloading data...") + try: + download(self.url, self.data_dir, "data.zip") + except Exception: + log.error( + f"Download from main url {self.url} failed. Trying backup url {self.backup_url}..." + ) + download(self.backup_url, self.data_dir, "data.zip") + log.info("Download complete. Unzipping...") + zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall( + self.data_dir + ) + log.info("Unzip complete. Removing zip file") + os.remove(os.path.join(self.data_dir, "data.zip")) + + def load_data(self, fpath, batch_size=[]): + data = load_npz_to_tensordict(fpath) + # rename key if necessary (old dpp version) + if "observation" in data.keys(): + data["locs"] = data.pop("observation") + return data diff --git a/rl4co/envs/eda/dpp/render.py b/rl4co/envs/eda/dpp/render.py new file mode 100644 index 00000000..fec5ecba --- /dev/null +++ b/rl4co/envs/eda/dpp/render.py @@ -0,0 +1,84 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(self, decaps, probe, action_mask, ax=None, legend=True): + """ + Plot a grid of 1x1 squares representing the environment. + The keepout regions are the action_mask - decaps - probe + """ + import matplotlib.pyplot as plt + + settings = { + 0: {"color": "white", "label": "available"}, + 1: {"color": "grey", "label": "keepout"}, + 2: {"color": "tab:red", "label": "probe"}, + 3: {"color": "tab:blue", "label": "decap"}, + } + + nonzero_indices = torch.nonzero(~action_mask, as_tuple=True)[0] + keepout = torch.cat([nonzero_indices, probe, decaps.squeeze(-1)]) + unique_elements, counts = torch.unique(keepout, return_counts=True) + keepout = unique_elements[counts == 1] + + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + + grid = np.meshgrid(np.arange(0, self.size), np.arange(0, self.size)) + grid = np.stack(grid, axis=-1) + + # Add new dimension to grid filled up with 0s + grid = np.concatenate([grid, np.zeros((self.size, self.size, 1))], axis=-1) + + # Add keepout = 1 + grid[keepout // self.size, keepout % self.size, 2] = 1 + # Add probe = 2 + grid[probe // self.size, probe % self.size, 2] = 2 + # Add decaps = 3 + grid[decaps // self.size, decaps % self.size, 2] = 3 + + xdim, ydim = grid.shape[0], grid.shape[1] + ax.imshow(np.zeros((xdim, ydim)), cmap="gray") + + ax.set_xlim(0, xdim) + ax.set_ylim(0, ydim) + + for i in range(xdim): + for j in range(ydim): + color = settings[grid[i, j, 2]]["color"] + x, y = grid[i, j, 0], grid[i, j, 1] + ax.add_patch(plt.Rectangle((x, y), 1, 1, color=color, linestyle="-")) + + # Add grid with 1x1 squares + ax.grid( + which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5 + ) + # set 10 ticks + ax.set_xticks(np.arange(0, xdim, 1)) + ax.set_yticks(np.arange(0, ydim, 1)) + + # Invert y axis + ax.invert_yaxis() + + # Add legend + if legend: + num_unique = 4 + handles = [ + plt.Rectangle((0, 0), 1, 1, color=settings[i]["color"]) + for i in range(num_unique) + ] + ax.legend( + handles, + [settings[i]["label"] for i in range(num_unique)], + ncol=num_unique, + loc="upper center", + bbox_to_anchor=(0.5, 1.1), + ) diff --git a/rl4co/envs/eda/mdpp.py b/rl4co/envs/eda/mdpp.py deleted file mode 100644 index cee75d4b..00000000 --- a/rl4co/envs/eda/mdpp.py +++ /dev/null @@ -1,345 +0,0 @@ -from typing import Optional - -import numpy as np -import torch - -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.envs.eda.dpp import DPPEnv -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -class MDPPEnv(DPPEnv): - """Multiple decap placement problem (mDPP) environment - This is a modified version of the DPP environment where we allow multiple probing ports - The reward can be calculated as: - - minmax: min of the max of the decap scores - - meansum: mean of the sum of the decap scores - The minmax is more challenging as it requires to find the best decap location for the worst case - - Args: - num_probes_min: minimum number of probes - num_probes_max: maximum number of probes - reward_type: reward type, either minmax or meansum - td_params: TensorDict parameters - """ - - name = "mdpp" - - def __init__( - self, - *, - num_probes_min: int = 2, - num_probes_max: int = 5, - reward_type: str = "minmax", - td_params: TensorDict = None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_probes_min = num_probes_min - self.num_probes_max = num_probes_max - assert reward_type in [ - "minmax", - "meansum", - ], "reward_type must be minmax or meansum" - self.reward_type = reward_type - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - # Step function is the same as DPPEnv, only masking changes - return super()._step(td) - - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - # Reset function is the same as DPPEnv, only masking changes due to probes - td_reset = super()._reset(td, batch_size=batch_size) - - # Action mask is 0 if both action_mask (e.g. keepout) and probe are 0 - action_mask = torch.logical_and(td_reset["action_mask"], ~td_reset["probe"]) - # Keepout regions are the inverse of action_mask - td_reset.update( - { - "keepout": ~td_reset["action_mask"], - "action_mask": action_mask, - } - ) - return td_reset - - def _make_spec(self, td_params): - """Make the observation and action specs from the parameters""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.size**2, 2), - dtype=torch.float32, - ), - probe=UnboundedDiscreteTensorSpec( - shape=(self.size**2), - dtype=torch.bool, - ), # probe is a boolean of multiple locations (1=probe, 0=not probe) - keepout=UnboundedDiscreteTensorSpec( - shape=(self.size**2), - dtype=torch.bool, - ), - i=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.size**2), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.size**2, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - def get_reward(self, td, actions): - """We call the reward function with the final sequence of actions to get the reward - Calling per-step would be very time consuming due to decap simulation - """ - # We do the operation in a batch - if len(td.batch_size) == 0: - td = td.unsqueeze(0) - actions = actions.unsqueeze(0) - - # Reward calculation is expensive since we need to run decap simulation (not vectorizable) - reward = torch.stack( - [ - self._single_env_reward(td_single, action) - for td_single, action in zip(td, actions) - ] - ) - return reward - - def _single_env_reward(self, td, actions): - """Get reward for single environment. We""" - - list_probe = torch.nonzero(td["probe"]).squeeze() - scores = torch.zeros_like(list_probe, dtype=torch.float32) - for i, probe in enumerate(list_probe): - # Get the decap scores for the probe location - scores[i] = self._decap_simulator(probe, actions) - # If minmax, return min of max decap scores else mean - return scores.min() if self.reward_type == "minmax" else scores.mean() - - def generate_data(self, batch_size): - """ - Generate initial observations for the environment with locations, probe, and action mask - Action_mask eliminates the keepout regions and the probe location, and is updated to eliminate placed decaps - """ - - m = n = self.size - # if int, convert to list and make it a batch for easier generation - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - batched = len(batch_size) > 0 - bs = [1] if not batched else batch_size - - # Create a list of locs on a grid - locs = torch.meshgrid(torch.arange(m), torch.arange(n)) - locs = torch.stack(locs, dim=-1).reshape(-1, 2) - # normalize the locations by the number of rows and columns - locs = locs / torch.tensor([m, n], dtype=torch.float) - locs = locs[None].expand(*bs, -1, -1) - - # Create available mask - available = torch.ones((*bs, m * n), dtype=torch.bool) - - # Sample probe location from m*n - probe = torch.randint(m * n, size=(*bs, 1)) - available.scatter_(1, probe, False) - - # Sample probe locatins - num_probe = torch.randint( - self.num_probes_min, - self.num_probes_max, - size=(*bs, 1), - ) - probe = [torch.randperm(m * n)[:p] for p in num_probe] - probes = torch.zeros((*bs, m * n), dtype=torch.bool) - for i, (a, p) in enumerate(zip(available, probe)): - available[i] = a.scatter(0, p, False) - probes[i] = probes[i].scatter(0, p, True) - - # Sample keepout locations from m*n except probe - num_keepout = torch.randint( - self.num_keepout_min, - self.num_keepout_max, - size=(*bs, 1), - ) - keepouts = [torch.randperm(m * n)[:k] for k in num_keepout] - for i, (a, k) in enumerate(zip(available, keepouts)): - available[i] = a.scatter(0, k, False) - - return TensorDict( - { - "locs": locs if batched else locs.squeeze(0), - "probe": probes if batched else probes.squeeze(0), - "action_mask": available if batched else available.squeeze(0), - }, - batch_size=batch_size, - ) - - def render(self, td, actions=None, ax=None, legend=True, settings=None): - """Plot a grid of squares representing the environment. - The keepout regions are the action_mask - decaps - probe - """ - - import matplotlib.pyplot as plt - - from matplotlib.lines import Line2D - from matplotlib.patches import Annulus, Rectangle, RegularPolygon - - if settings is None: - settings = { - "available": {"color": "white", "label": "available"}, - "keepout": {"color": "grey", "label": "keepout"}, - "probe": {"color": "tab:red", "label": "probe"}, - "decap": {"color": "tab:blue", "label": "decap"}, - } - - def draw_capacitor(ax, x, y, color="black"): - # Backgrund rectangle: same as color but with alpha=0.5 - ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) - - # Create the plates of the capacitor - plate_width, plate_height = ( - 0.3, - 0.1, - ) # Width and height switched to make vertical - plate_gap = 0.2 - plate1 = Rectangle( - (x + 0.5 - plate_width / 2, y + 0.5 - plate_height - plate_gap / 2), - plate_width, - plate_height, - color=color, - ) - plate2 = Rectangle( - (x + 0.5 - plate_width / 2, y + 0.5 + plate_gap / 2), - plate_width, - plate_height, - color=color, - ) - - # Add the plates to the axes - ax.add_patch(plate1) - ax.add_patch(plate2) - - # Add connection lines (wires) - line_length = 0.2 - line1 = Line2D( - [x + 0.5, x + 0.5], - [ - y + 0.5 - plate_height - plate_gap / 2 - line_length, - y + 0.5 - plate_height - plate_gap / 2, - ], - color=color, - ) - line2 = Line2D( - [x + 0.5, x + 0.5], - [ - y + 0.5 + plate_height + plate_gap / 2, - y + 0.5 + plate_height + plate_gap / 2 + line_length, - ], - color=color, - ) - - # Add the lines to the axes - ax.add_line(line1) - ax.add_line(line2) - - def draw_probe(ax, x, y, color="black"): - # Backgrund rectangle: same as color but with alpha=0.5 - ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) - ax.add_patch(Annulus((x + 0.5, y + 0.5), (0.2, 0.2), 0.1, color=color)) - - def draw_keepout(ax, x, y, color="black"): - # Backgrund rectangle: same as color but with alpha=0.5 - ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) - ax.add_patch( - RegularPolygon( - (x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color - ) - ) - - size = self.size - td = td.detach().cpu() - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - - if actions is None: - actions = td.get("action", None) - - # Transform actions from idx to one-hot - decaps = torch.zeros(size**2) - decaps.scatter_(0, actions, 1) - decaps = decaps.reshape(size, size) - - keepout = ~td["action_mask"].reshape(size, size) - probes = td["probe"].reshape(size, size) - - if ax is None: - _, ax = plt.subplots(1, 1, figsize=(6, 6)) - - grid = np.meshgrid(np.arange(0, size), np.arange(0, size)) - grid = np.stack(grid, axis=-1) - - xdim, ydim = grid.shape[0], grid.shape[1] - # ax.imshow(np.zeros((xdim, ydim)), cmap="gray") - - ax.set_xlim(0, xdim) - ax.set_ylim(0, ydim) - - for i in range(xdim): - for j in range(ydim): - x, y = grid[i, j, 0], grid[i, j, 1] - - if decaps[i, j] == 1: - draw_capacitor(ax, x, y, color=settings["decap"]["color"]) - elif probes[i, j] == 1: - draw_probe(ax, x, y, color=settings["probe"]["color"]) - elif keepout[i, j] == 1: - draw_keepout(ax, x, y, color=settings["keepout"]["color"]) - - ax.grid( - which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5 - ) - # set 10 ticks - ax.set_xticks(np.arange(0, xdim, 1)) - ax.set_yticks(np.arange(0, ydim, 1)) - - # Invert y axis - ax.invert_yaxis() - - # # Add legend - if legend: - colors = [settings[k]["color"] for k in settings.keys()] - labels = [settings[k]["label"] for k in settings.keys()] - handles = [ - plt.Rectangle( - (0, 0), 1, 1, color=c, edgecolor="k", linestyle="-", linewidth=1 - ) - for c in colors - ] - ax.legend( - handles, - [label for label in labels], - ncol=len(colors), - loc="upper center", - bbox_to_anchor=(0.5, 1.1), - ) diff --git a/rl4co/envs/eda/mdpp/env.py b/rl4co/envs/eda/mdpp/env.py new file mode 100644 index 00000000..1da98814 --- /dev/null +++ b/rl4co/envs/eda/mdpp/env.py @@ -0,0 +1,161 @@ +from typing import Optional + +import numpy as np +import torch + +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.envs.eda.dpp.env import DPPEnv +from rl4co.utils.pylogger import get_pylogger + +from .generator import MDPPGenerator +from .render import render + +log = get_pylogger(__name__) + + +class MDPPEnv(DPPEnv): + """Multiple decap placement problem (mDPP) environment + This is a modified version of the DPP environment where we allow multiple probing ports + + Observations: + - locations of the probing ports and keepout regions + - current decap placement + - remaining decaps + + Constraints: + - decaps cannot be placed at the probing ports or keepout regions + - the number of decaps is limited + + Finish Condition: + - the number of decaps exceeds the limit + + Reward: + - the impedance suppression at the probing ports + + Args: + generator: DPPGenerator instance as the data generator + generator_params: parameters for the generator + reward_type: reward type, either minmax or meansum + - minmax: min of the max of the decap scores + - meansum: mean of the sum of the decap scores + + Note: + The minmax is more challenging as it requires to find the best decap location + for the worst case + """ + + name = "mdpp" + + def __init__( + self, + generator: MDPPGenerator = None, + generator_params: dict = {}, + reward_type: str = "minmax", + **kwargs, + ): + super().__init__(**kwargs) + if generator is None: + generator = MDPPGenerator(**generator_params) + self.generator = generator + + assert reward_type in [ + "minmax", + "meansum", + ], "reward_type must be minmax or meansum" + self.reward_type = reward_type + + self._make_spec(self.generator) + + def _step(self, td: TensorDict) -> TensorDict: + # Step function is the same as DPPEnv, only masking changes + return super()._step(td) + + def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + # Reset function is the same as DPPEnv, only masking changes due to probes + td_reset = super()._reset(td, batch_size=batch_size) + + # Action mask is 0 if both action_mask (e.g. keepout) and probe are 0 + action_mask = torch.logical_and(td_reset["action_mask"], ~td_reset["probe"]) + # Keepout regions are the inverse of action_mask + td_reset.update( + { + "keepout": ~td_reset["action_mask"], + "action_mask": action_mask, + } + ) + return td_reset + + def _make_spec(self, generator: MDPPGenerator): + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.size**2, 2), + dtype=torch.float32, + ), + probe=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + keepout=UnboundedDiscreteTensorSpec( + shape=(generator.size**2), + dtype=torch.bool, + ), + i=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(generator.size**2), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=generator.size**2, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + def _get_reward(self, td, actions): + """We call the reward function with the final sequence of actions to get the reward + Calling per-step would be very time consuming due to decap simulation + """ + # We do the operation in a batch + if len(td.batch_size) == 0: + td = td.unsqueeze(0) + actions = actions.unsqueeze(0) + + # Reward calculation is expensive since we need to run decap simulation (not vectorizable) + reward = torch.stack( + [ + self._single_env_reward(td_single, action) + for td_single, action in zip(td, actions) + ] + ) + return reward + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + assert True, "Not implemented" + + def _single_env_reward(self, td, actions): + """Get reward for single environment. We""" + + list_probe = torch.nonzero(td["probe"]).squeeze() + scores = torch.zeros_like(list_probe, dtype=torch.float32) + for i, probe in enumerate(list_probe): + # Get the decap scores for the probe location + scores[i] = self._decap_simulator(probe, actions) + # If minmax, return min of max decap scores else mean + return scores.min() if self.reward_type == "minmax" else scores.mean() diff --git a/rl4co/envs/eda/mdpp/generator.py b/rl4co/envs/eda/mdpp/generator.py new file mode 100644 index 00000000..75767150 --- /dev/null +++ b/rl4co/envs/eda/mdpp/generator.py @@ -0,0 +1,178 @@ +import os +import zipfile +from typing import Union, Callable + +import torch +import numpy as np + +from robust_downloader import download +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class MDPPGenerator(Generator): + """Data generator for the Multi Decap Placement Problem (MDPP). + + Args: + min_loc: Minimum location value. Defaults to 0. + max_loc: Maximum location value. Defaults to 1. + num_keepout_min: Minimum number of keepout regions. Defaults to 1. + num_keepout_max: Maximum number of keepout regions. Defaults to 50. + max_decaps: Maximum number of decaps. Defaults to 20. + data_dir: Directory to store data. Defaults to "data/dpp/". + This can be downloaded from this [url](https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95). + chip_file: Name of the chip file. Defaults to "10x10_pkg_chip.npy". + decap_file: Name of the decap file. Defaults to "01nF_decap.npy". + freq_file: Name of the frequency file. Defaults to "freq_201.npy". + url: URL to download data from. Defaults to None. + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + demand [batch_size, num_loc]: demand of each customer + capacity [batch_size]: capacity of the vehicle + """ + def __init__( + self, + min_loc: float = 0.0, + max_loc: float = 1.0, + num_keepout_min: int = 1, + num_keepout_max: int = 50, + num_probes_min: int = 2, + num_probes_max: int = 5, + max_decaps: int = 20, + data_dir: str = "data/dpp/", + chip_file: str = "10x10_pkg_chip.npy", + decap_file: str = "01nF_decap.npy", + freq_file: str = "freq_201.npy", + url: str = None, + **unused_kwargs + ): + self.min_loc = min_loc + self.max_loc = max_loc + self.num_keepout_min = num_keepout_min + self.num_keepout_max = num_keepout_max + self.num_probes_min = num_probes_min + self.num_probes_max = num_probes_max + self.max_decaps = max_decaps + self.data_dir = data_dir + + # DPP environment doen't have any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + + # Download and load the data from online dataset + self.url = ( + "https://github.com/kaist-silab/devformer/raw/main/data/data.zip" + if url is None + else url + ) + self.backup_url = ( + "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" + ) + self._load_dpp_data(chip_file, decap_file, freq_file) + + # Check the validity of the keepout parameters + assert ( + num_keepout_min <= num_keepout_max + ), "num_keepout_min must be <= num_keepout_max" + assert ( + num_keepout_max <= self.size**2 + ), "num_keepout_max must be <= size * size (total number of locations)" + + def _generate(self, batch_size) -> TensorDict: + m = n = self.size + # if int, convert to list and make it a batch for easier generation + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + batched = len(batch_size) > 0 + bs = [1] if not batched else batch_size + + # Create a list of locs on a grid + locs = torch.meshgrid(torch.arange(m), torch.arange(n)) + locs = torch.stack(locs, dim=-1).reshape(-1, 2) + # normalize the locations by the number of rows and columns + locs = locs / torch.tensor([m, n], dtype=torch.float) + locs = locs[None].expand(*bs, -1, -1) + + # Create available mask + available = torch.ones((*bs, m * n), dtype=torch.bool) + + # Sample probe location from m*n + probe = torch.randint(m * n, size=(*bs, 1)) + available.scatter_(1, probe, False) + + # Sample probe locatins + num_probe = torch.randint( + self.num_probes_min, + self.num_probes_max, + size=(*bs, 1), + ) + probe = [torch.randperm(m * n)[:p] for p in num_probe] + probes = torch.zeros((*bs, m * n), dtype=torch.bool) + for i, (a, p) in enumerate(zip(available, probe)): + available[i] = a.scatter(0, p, False) + probes[i] = probes[i].scatter(0, p, True) + + # Sample keepout locations from m*n except probe + num_keepout = torch.randint( + self.num_keepout_min, + self.num_keepout_max, + size=(*bs, 1), + ) + keepouts = [torch.randperm(m * n)[:k] for k in num_keepout] + for i, (a, k) in enumerate(zip(available, keepouts)): + available[i] = a.scatter(0, k, False) + + return TensorDict( + { + "locs": locs if batched else locs.squeeze(0), + "probe": probes if batched else probes.squeeze(0), + "action_mask": available if batched else available.squeeze(0), + }, + batch_size=batch_size, + ) + + def _load_dpp_data(self, chip_file, decap_file, freq_file): + def _load_file(fpath): + f = os.path.join(self.data_dir, fpath) + if not os.path.isfile(f): + self._download_data() + with open(f, "rb") as f_: + return torch.from_numpy(np.load(f_)) + + self.raw_pdn = _load_file(chip_file) # [num_freq, size^2, size^2] + self.decap = _load_file(decap_file).to(torch.complex64) # [num_freq, 1, 1] + self.freq = _load_file(freq_file) # [num_freq] + self.size = int(np.sqrt(self.raw_pdn.shape[-1])) + self.num_freq = self.freq.shape[0] + + def _download_data(self): + log.info("Downloading data...") + try: + download(self.url, self.data_dir, "data.zip") + except Exception: + log.error( + f"Download from main url {self.url} failed. Trying backup url {self.backup_url}..." + ) + download(self.backup_url, self.data_dir, "data.zip") + log.info("Download complete. Unzipping...") + zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall( + self.data_dir + ) + log.info("Unzip complete. Removing zip file") + os.remove(os.path.join(self.data_dir, "data.zip")) + + def load_data(self, fpath, batch_size=[]): + data = load_npz_to_tensordict(fpath) + # rename key if necessary (old dpp version) + if "observation" in data.keys(): + data["locs"] = data.pop("observation") + return data diff --git a/rl4co/envs/eda/mdpp/render.py b/rl4co/envs/eda/mdpp/render.py new file mode 100644 index 00000000..fbd4cd00 --- /dev/null +++ b/rl4co/envs/eda/mdpp/render.py @@ -0,0 +1,161 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(self, td, actions=None, ax=None, legend=True, settings=None): + """Plot a grid of squares representing the environment. + The keepout regions are the action_mask - decaps - probe + """ + + import matplotlib.pyplot as plt + + from matplotlib.lines import Line2D + from matplotlib.patches import Annulus, Rectangle, RegularPolygon + + if settings is None: + settings = { + "available": {"color": "white", "label": "available"}, + "keepout": {"color": "grey", "label": "keepout"}, + "probe": {"color": "tab:red", "label": "probe"}, + "decap": {"color": "tab:blue", "label": "decap"}, + } + + def draw_capacitor(ax, x, y, color="black"): + # Backgrund rectangle: same as color but with alpha=0.5 + ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) + + # Create the plates of the capacitor + plate_width, plate_height = ( + 0.3, + 0.1, + ) # Width and height switched to make vertical + plate_gap = 0.2 + plate1 = Rectangle( + (x + 0.5 - plate_width / 2, y + 0.5 - plate_height - plate_gap / 2), + plate_width, + plate_height, + color=color, + ) + plate2 = Rectangle( + (x + 0.5 - plate_width / 2, y + 0.5 + plate_gap / 2), + plate_width, + plate_height, + color=color, + ) + + # Add the plates to the axes + ax.add_patch(plate1) + ax.add_patch(plate2) + + # Add connection lines (wires) + line_length = 0.2 + line1 = Line2D( + [x + 0.5, x + 0.5], + [ + y + 0.5 - plate_height - plate_gap / 2 - line_length, + y + 0.5 - plate_height - plate_gap / 2, + ], + color=color, + ) + line2 = Line2D( + [x + 0.5, x + 0.5], + [ + y + 0.5 + plate_height + plate_gap / 2, + y + 0.5 + plate_height + plate_gap / 2 + line_length, + ], + color=color, + ) + + # Add the lines to the axes + ax.add_line(line1) + ax.add_line(line2) + + def draw_probe(ax, x, y, color="black"): + # Backgrund rectangle: same as color but with alpha=0.5 + ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) + ax.add_patch(Annulus((x + 0.5, y + 0.5), (0.2, 0.2), 0.1, color=color)) + + def draw_keepout(ax, x, y, color="black"): + # Backgrund rectangle: same as color but with alpha=0.5 + ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) + ax.add_patch( + RegularPolygon( + (x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color + ) + ) + + size = self.size + td = td.detach().cpu() + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + + if actions is None: + actions = td.get("action", None) + + # Transform actions from idx to one-hot + decaps = torch.zeros(size**2) + decaps.scatter_(0, actions, 1) + decaps = decaps.reshape(size, size) + + keepout = ~td["action_mask"].reshape(size, size) + probes = td["probe"].reshape(size, size) + + if ax is None: + _, ax = plt.subplots(1, 1, figsize=(6, 6)) + + grid = np.meshgrid(np.arange(0, size), np.arange(0, size)) + grid = np.stack(grid, axis=-1) + + xdim, ydim = grid.shape[0], grid.shape[1] + # ax.imshow(np.zeros((xdim, ydim)), cmap="gray") + + ax.set_xlim(0, xdim) + ax.set_ylim(0, ydim) + + for i in range(xdim): + for j in range(ydim): + x, y = grid[i, j, 0], grid[i, j, 1] + + if decaps[i, j] == 1: + draw_capacitor(ax, x, y, color=settings["decap"]["color"]) + elif probes[i, j] == 1: + draw_probe(ax, x, y, color=settings["probe"]["color"]) + elif keepout[i, j] == 1: + draw_keepout(ax, x, y, color=settings["keepout"]["color"]) + + ax.grid( + which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5 + ) + # set 10 ticks + ax.set_xticks(np.arange(0, xdim, 1)) + ax.set_yticks(np.arange(0, ydim, 1)) + + # Invert y axis + ax.invert_yaxis() + + # # Add legend + if legend: + colors = [settings[k]["color"] for k in settings.keys()] + labels = [settings[k]["label"] for k in settings.keys()] + handles = [ + plt.Rectangle( + (0, 0), 1, 1, color=c, edgecolor="k", linestyle="-", linewidth=1 + ) + for c in colors + ] + ax.legend( + handles, + [label for label in labels], + ncol=len(colors), + loc="upper center", + bbox_to_anchor=(0.5, 1.1), + ) diff --git a/rl4co/envs/routing/__init__.py b/rl4co/envs/routing/__init__.py index 392f2640..6f59e546 100644 --- a/rl4co/envs/routing/__init__.py +++ b/rl4co/envs/routing/__init__.py @@ -1,12 +1,12 @@ -from rl4co.envs.routing.atsp import ATSPEnv -from rl4co.envs.routing.cvrp import CVRPEnv -from rl4co.envs.routing.cvrptw import CVRPTWEnv -from rl4co.envs.routing.mtsp import MTSPEnv -from rl4co.envs.routing.op import OPEnv -from rl4co.envs.routing.pctsp import PCTSPEnv -from rl4co.envs.routing.pdp import PDPEnv -from rl4co.envs.routing.sdvrp import SDVRPEnv -from rl4co.envs.routing.spctsp import SPCTSPEnv -from rl4co.envs.routing.svrp import SVRPEnv -from rl4co.envs.routing.tsp import TSPEnv -from rl4co.envs.routing.mdcpdp import MDCPDPEnv +from rl4co.envs.routing.atsp.env import ATSPEnv +from rl4co.envs.routing.cvrp.env import CVRPEnv +from rl4co.envs.routing.cvrptw.env import CVRPTWEnv +from rl4co.envs.routing.mtsp.env import MTSPEnv +from rl4co.envs.routing.op.env import OPEnv +from rl4co.envs.routing.pctsp.env import PCTSPEnv +from rl4co.envs.routing.pdp.env import PDPEnv +from rl4co.envs.routing.sdvrp.env import SDVRPEnv +from rl4co.envs.routing.spctsp.env import SPCTSPEnv +from rl4co.envs.routing.svrp.env import SVRPEnv +from rl4co.envs.routing.tsp.env import TSPEnv +from rl4co.envs.routing.mdcpdp.env import MDCPDPEnv diff --git a/rl4co/envs/routing/atsp.py b/rl4co/envs/routing/atsp/env.py similarity index 51% rename from rl4co/envs/routing/atsp.py rename to rl4co/envs/routing/atsp/env.py index 2e8fd8da..d02fc142 100644 --- a/rl4co/envs/routing/atsp.py +++ b/rl4co/envs/routing/atsp/env.py @@ -14,40 +14,52 @@ from rl4co.envs.common.utils import batch_to_scalar from rl4co.utils.pylogger import get_pylogger +from .generator import ATSPGenerator +from .render import render + log = get_pylogger(__name__) class ATSPEnv(RL4COEnvBase): - """ - Asymmetric Traveling Salesman Problem environment - At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities. + """Asymmetric Traveling Salesman Problem (ATSP) environment + At each step, the agent chooses a customer to visit. The reward is 0 unless the agent visits all the customers. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. Unlike the TSP, the distance matrix is asymmetric, i.e., the distance from A to B is not necessarily the same as the distance from B to A. + Observations: + - distance matrix between customers + - the current customer + - the first customer (for calculating the reward) + - the remaining unvisited customers + + Constraints: + - the tour starts and ends at the same customer. + - each customer must be visited exactly once. + + Finish Condition: + - the agent has visited all customers. + + Reward: + - (minus) the negative length of the path. + Args: - num_loc: number of locations (cities) in the TSP - td_params: parameters of the environment - seed: seed for the environment - device: device to use. Generally, no need to set as tensors are updated on the fly + generator: ATSPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "atsp" def __init__( self, - num_loc: int = 10, - min_dist: float = 0, - max_dist: float = 1, - tmat_class: bool = True, - td_params: TensorDict = None, + generator: ATSPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_dist = min_dist - self.max_dist = max_dist - self.tmat_class = tmat_class - self._make_spec(td_params) + if generator is None: + generator = ATSPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) @staticmethod def _step(td: TensorDict) -> TensorDict: @@ -79,24 +91,13 @@ def _step(td: TensorDict) -> TensorDict: def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize distance matrix - cost_matrix = ( - td["cost_matrix"] if td is not None else None - ) # dm = distance matrix - if batch_size is None: - batch_size = ( - self.batch_size if cost_matrix is None else cost_matrix.shape[:-2] - ) - device = cost_matrix.device if cost_matrix is not None else self.device - self.to(device) - if cost_matrix is None: - cost_matrix = self.generate_data(batch_size=batch_size).to(device)[ - "cost_matrix" - ] + cost_matrix = td["cost_matrix"] + device = td.device # Other variables current_node = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) available = torch.ones( - (*batch_size, self.num_loc), dtype=torch.bool, device=device + (*batch_size, self.generator.num_loc), dtype=torch.bool, device=device ) # 1 means not visited, i.e. action is allowed i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) @@ -111,12 +112,12 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict = None): + def _make_spec(self, generator: ATSPGenerator): self.observation_spec = CompositeSpec( cost_matrix=BoundedTensorSpec( - low=self.min_dist, - high=self.max_dist, - shape=(self.num_loc, self.num_loc), + low=generator.min_dist, + high=generator.max_dist, + shape=(generator.num_loc, generator.num_loc), dtype=torch.float32, ), first_node=UnboundedDiscreteTensorSpec( @@ -132,7 +133,7 @@ def _make_spec(self, td_params: TensorDict = None): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc), + shape=(generator.num_loc), dtype=torch.bool, ), shape=(), @@ -141,19 +142,13 @@ def _make_spec(self, td_params: TensorDict = None): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc, + high=generator.num_loc, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - def get_reward(self, td, actions) -> TensorDict: + def _get_reward(self, td, actions) -> TensorDict: distance_matrix = td["cost_matrix"] - assert ( - torch.arange(actions.size(1), out=actions.data.new()) - .view(1, -1) - .expand_as(actions) - == actions.data.sort(1)[0] - ).all(), "Invalid tour" # Get indexes of tour edges nodes_src = actions @@ -164,65 +159,15 @@ def get_reward(self, td, actions) -> TensorDict: # return negative tour length return -distance_matrix[batch_idx, nodes_src, nodes_tgt].sum(-1) - def generate_data(self, batch_size) -> TensorDict: - # Generate distance matrices inspired by the reference MatNet (Kwon et al., 2021) - # We satifsy the triangle inequality (TMAT class) in a batch - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - dms = ( - torch.rand((*batch_size, self.num_loc, self.num_loc), generator=self.rng) - * (self.max_dist - self.min_dist) - + self.min_dist - ) - dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0 - log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class)) - if self.tmat_class: - while True: - old_dms = dms.clone() - dms, _ = ( - dms[..., :, None, :] + dms[..., None, :, :].transpose(-2, -1) - ).min(dim=-1) - if (dms == old_dms).all(): - break - return TensorDict({"cost_matrix": dms}, batch_size=batch_size) + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + assert ( + torch.arange(actions.size(1), out=actions.data.new()) + .view(1, -1) + .expand_as(actions) + == actions.data.sort(1)[0] + ).all(), "Invalid tour" @staticmethod def render(td, actions=None, ax=None): - try: - import networkx as nx - except ImportError: - log.warn( - "Networkx is not installed. Please install it with `pip install networkx`" - ) - return - - td = td.detach().cpu() - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - src_nodes = actions - tgt_nodes = torch.roll(actions, 1, dims=0) - - # Plot with networkx - G = nx.DiGraph(td["cost_matrix"].numpy()) - pos = nx.spring_layout(G) - nx.draw( - G, - pos, - with_labels=True, - node_color="skyblue", - node_size=800, - edge_color="white", - ) - - # draw edges src_nodes -> tgt_nodes - edgelist = [ - (src_nodes[i].item(), tgt_nodes[i].item()) for i in range(len(src_nodes)) - ] - nx.draw_networkx_edges( - G, pos, edgelist=edgelist, width=2, alpha=1, edge_color="black" - ) + return render(td, actions, ax) diff --git a/rl4co/envs/routing/atsp/generator.py b/rl4co/envs/routing/atsp/generator.py new file mode 100644 index 00000000..89e381ca --- /dev/null +++ b/rl4co/envs/routing/atsp/generator.py @@ -0,0 +1,71 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class ATSPGenerator(Generator): + """Data generator for the Asymmetric Travelling Salesman Problem (ATSP) + Generate distance matrices inspired by the reference MatNet (Kwon et al., 2021) + We satifsy the triangle inequality (TMAT class) in a batch + + Args: + num_loc: number of locations (customers) in the TSP + min_dist: minimum value for the distance between nodes + max_dist: maximum value for the distance between nodes + dist_distribution: distribution for the distance between nodes + tmat_class: whether to generate a class of distance matrix + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + """ + def __init__( + self, + num_loc: int = 10, + min_dist: float = 0.0, + max_dist: float = 1.0, + dist_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + tmat_class: bool = True, + **kwargs + ): + self.num_loc = num_loc + self.min_dist = min_dist + self.max_dist = max_dist + self.tmat_class = tmat_class + + # Distance distribution + if kwargs.get("dist_sampler", None) is not None: + self.dist_sampler = kwargs["dist_sampler"] + else: + self.dist_sampler = get_sampler("dist", dist_distribution, 0.0, 1.0, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Generate distance matrices inspired by the reference MatNet (Kwon et al., 2021) + # We satifsy the triangle inequality (TMAT class) in a batch + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + dms = ( + self.dist_sampler.sample((batch_size + [self.num_loc, self.num_loc])) + * (self.max_dist - self.min_dist) + + self.min_dist + ) + dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0 + log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class)) + if self.tmat_class: + while True: + old_dms = dms.clone() + dms, _ = ( + dms[..., :, None, :] + dms[..., None, :, :].transpose(-2, -1) + ).min(dim=-1) + if (dms == old_dms).all(): + break + return TensorDict({"cost_matrix": dms}, batch_size=batch_size) diff --git a/rl4co/envs/routing/atsp/render.py b/rl4co/envs/routing/atsp/render.py new file mode 100644 index 00000000..8ad0a903 --- /dev/null +++ b/rl4co/envs/routing/atsp/render.py @@ -0,0 +1,50 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # If batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + + # Gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + actions = actions.detach().cpu() + locs = gather_by_index(locs, actions, dim=0) + + # Cat the first node to the end to complete the tour + locs = torch.cat((locs, locs[0:1])) + x, y = locs[:, 0], locs[:, 1] + + # Plot the visited nodes + ax.scatter(x, y, color="tab:blue") + + # Add arrows between visited nodes as a quiver plot + dx, dy = np.diff(x), np.diff(y) + ax.quiver( + x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k" + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/cvrp.py b/rl4co/envs/routing/cvrp.py deleted file mode 100644 index 4a28f7bc..00000000 --- a/rl4co/envs/routing/cvrp.py +++ /dev/null @@ -1,437 +0,0 @@ -from typing import Optional - -import torch - -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.data.utils import load_npz_to_tensordict -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.ops import gather_by_index, get_tour_length -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 -CAPACITIES = { - 10: 20.0, - 15: 25.0, - 20: 30.0, - 30: 33.0, - 40: 37.0, - 50: 40.0, - 60: 43.0, - 75: 45.0, - 100: 50.0, - 125: 55.0, - 150: 60.0, - 200: 70.0, - 500: 100.0, - 1000: 150.0, -} - - -class CVRPEnv(RL4COEnvBase): - """Capacitated Vehicle Routing Problem (CVRP) environment. - At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. - When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. - In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. - - Args: - num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) - min_loc: minimum value for the location coordinates - max_loc: maximum value for the location coordinates - min_demand: minimum value for the demand of each customer - max_demand: maximum value for the demand of each customer - vehicle_capacity: capacity of the vehicle - td_params: parameters of the environment - """ - - name = "cvrp" - - def __init__( - self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_demand: float = 1, - max_demand: float = 10, - vehicle_capacity: float = 1.0, - capacity: float = None, - td_params: TensorDict = None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_demand = min_demand - self.max_demand = max_demand - self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity - if self.capacity is None: - raise ValueError( - f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." - ) - self.vehicle_capacity = vehicle_capacity - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - current_node = td["action"][:, None] # Add dimension for step - n_loc = td["demand"].size(-1) # Excludes depot - - # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! - selected_demand = gather_by_index( - td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False - ) - - # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() - - # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot - # Add one dimension since we write a single value - visited = td["visited"].scatter(-1, current_node[..., None], 1) - - # SECTION: get done - done = visited.sum(-1) == visited.size(-1) - reward = torch.zeros_like(done) - - td.update( - { - "current_node": current_node, - "used_capacity": used_capacity, - "visited": visited, - "reward": reward, - "done": done, - } - ) - td.set("action_mask", self.get_action_mask(td)) - return td - - def _reset( - self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, - ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - - # Create reset TensorDict - td_reset = TensorDict( - { - "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), - "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device - ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), - "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device - ), - "visited": torch.zeros( - (*batch_size, 1, td["locs"].shape[-2] + 1), - dtype=torch.uint8, - device=self.device, - ), - }, - batch_size=batch_size, - ) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - return td_reset - - @staticmethod - def get_action_mask(td: TensorDict) -> torch.Tensor: - # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting - exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] - ) - - # Nodes that cannot be visited are already visited or too much demand to be served now - mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap - - # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) - return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) - - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - # Check that the solution is valid - if self.check_solution: - self.check_solution_validity(td, actions) - - # Gather dataset in order of tour - batch_size = td["locs"].shape[0] - depot = td["locs"][..., 0:1, :] - locs_ordered = torch.cat( - [ - depot, - gather_by_index(td["locs"], actions).reshape( - [batch_size, actions.size(-1), 2] - ), - ], - dim=1, - ) - return -get_tour_length(locs_ordered) - - @staticmethod - def check_solution_validity(td: TensorDict, actions: torch.Tensor): - """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" - # Check if tour is valid, i.e. contain 0 to n-1 - batch_size, graph_size = td["demand"].size() - sorted_pi = actions.data.sort(1)[0] - - # Sorting it should give all zeros at front and then 1...n - assert ( - torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) - .view(1, -1) - .expand(batch_size, graph_size) - == sorted_pi[:, -graph_size:] - ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" - - # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) - demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) - d = demand_with_depot.gather(1, actions) - - used_cap = torch.zeros_like(td["demand"][:, 0]) - for i in range(actions.size(1)): - used_cap += d[ - :, i - ] # This will reset/make capacity negative if i == 0, e.g. depot visited - # Cannot use less than 0 - used_cap[used_cap < 0] = 0 - assert ( - used_cap <= td["vehicle_capacity"] + 1e-5 - ).all(), "Used more than capacity" - - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize the demand for nodes except the depot - # Demand sampling Following Kool et al. (2019) - # Generates a slightly different distribution than using torch.randint - demand = ( - ( - torch.FloatTensor(*batch_size, self.num_loc) - .uniform_(self.min_demand - 1, self.max_demand - 1) - .int() - + 1 - ) - .float() - .to(self.device) - ) - - # Support for heterogeneous capacity if provided - if not isinstance(self.capacity, torch.Tensor): - capacity = torch.full((*batch_size,), self.capacity, device=self.device) - else: - capacity = self.capacity - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "demand": demand / self.capacity, - "capacity": capacity, - }, - batch_size=batch_size, - device=self.device, - ) - - @staticmethod - def load_data(fpath, batch_size=[]): - """Dataset loading from file - Normalize demand by capacity to be in [0, 1] - """ - td_load = load_npz_to_tensordict(fpath) - td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) - return td_load - - def _make_spec(self, td_params: TensorDict): - """Make the observation and action specs from the parameters.""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), - dtype=torch.float32, - ), - current_node=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - demand=BoundedTensorSpec( - low=-self.capacity, - high=self.max_demand, - shape=(self.num_loc, 1), # demand is only for customers - dtype=torch.float32, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.num_loc + 1, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render( - td: TensorDict, - actions=None, - ax=None, - scale_xy: bool = True, - ): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) - demands = td["demand"] * scale_demand - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # plot demand bars - for node_idx in range(1, len(locs)): - ax.add_patch( - plt.Rectangle( - (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), - 0.01, - demands[node_idx - 1] / (scale_demand * 10), - edgecolor=cm.Set2(0), - facecolor=cm.Set2(0), - fill=True, - ) - ) - - # text demand - for node_idx in range(1, len(locs)): - ax.text( - locs[node_idx, 0], - locs[node_idx, 1] - 0.025, - f"{demands[node_idx-1].item():.2f}", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(0), - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - - # Setup limits and show - if scale_xy: - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) - plt.show() diff --git a/rl4co/envs/routing/cvrp/env.py b/rl4co/envs/routing/cvrp/env.py new file mode 100644 index 00000000..6f370e10 --- /dev/null +++ b/rl4co/envs/routing/cvrp/env.py @@ -0,0 +1,235 @@ +from typing import Optional + +import torch + +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.utils.ops import gather_by_index, get_tour_length +from rl4co.utils.pylogger import get_pylogger + +from .generator import CVRPGenerator +from .render import render + +log = get_pylogger(__name__) + + +class CVRPEnv(RL4COEnvBase): + """Capacitated Vehicle Routing Problem (CVRP) environment. + At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. + When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to + visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. + In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + + Observations: + - location of the depot. + - locations and demand of each customer. + - current location of the vehicle. + - the remaining customer of the vehicle, + + Constraints: + - the tour starts and ends at the depot. + - each customer must be visited exactly once. + - the vehicle cannot visit customers exceed the remaining capacity. + - the vehicle can return to the depot to refill the capacity. + + Finish Condition: + - the vehicle has visited all customers and returned to the depot. + + Reward: + - (minus) the negative length of the path. + + Args: + generator: CVRPGenerator instance as the data generator + generator_params: parameters for the generator + """ + + name = "cvrp" + + def __init__( + self, + generator: CVRPGenerator = None, + generator_params: dict = {}, + **kwargs, + ): + super().__init__(**kwargs) + if generator is None: + generator = CVRPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) + + def _step(self, td: TensorDict) -> TensorDict: + current_node = td["action"][:, None] # Add dimension for step + n_loc = td["demand"].size(-1) # Excludes depot + + # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! + selected_demand = gather_by_index( + td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False + ) + + # Increase capacity if depot is not visited, otherwise set to 0 + used_capacity = (td["used_capacity"] + selected_demand) * ( + current_node != 0 + ).float() + + # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot + # Add one dimension since we write a single value + visited = td["visited"].scatter(-1, current_node[..., None], 1) + + # SECTION: get done + done = visited.sum(-1) == visited.size(-1) + reward = torch.zeros_like(done) + + td.update( + { + "current_node": current_node, + "used_capacity": used_capacity, + "visited": visited, + "reward": reward, + "done": done, + } + ) + td.set("action_mask", self.get_action_mask(td)) + return td + + def _reset( + self, + td: Optional[TensorDict] = None, + batch_size: Optional[list] = None, + ) -> TensorDict: + device = td.device + + # Create reset TensorDict + td_reset = TensorDict( + { + "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), + "demand": td["demand"], + "current_node": torch.zeros( + *batch_size, 1, dtype=torch.long, device=device + ), + "used_capacity": torch.zeros((*batch_size, 1), device=device), + "vehicle_capacity": torch.full( + (*batch_size, 1), self.generator.vehicle_capacity, device=device + ), + "visited": torch.zeros( + (*batch_size, 1, td["locs"].shape[-2] + 1), + dtype=torch.uint8, + device=device, + ), + }, + batch_size=batch_size, + ) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + return td_reset + + @staticmethod + def get_action_mask(td: TensorDict) -> torch.Tensor: + # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting + exceeds_cap = ( + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] + ) + + # Nodes that cannot be visited are already visited or too much demand to be served now + mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap + + # Cannot visit the depot if just visited and still unserved nodes + mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) + return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) + + def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + # Gather dataset in order of tour + batch_size = td["locs"].shape[0] + depot = td["locs"][..., 0:1, :] + locs_ordered = torch.cat( + [ + depot, + gather_by_index(td["locs"], actions).reshape( + [batch_size, actions.size(-1), 2] + ), + ], + dim=1, + ) + return -get_tour_length(locs_ordered) + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" + # Check if tour is valid, i.e. contain 0 to n-1 + batch_size, graph_size = td["demand"].size() + sorted_pi = actions.data.sort(1)[0] + + # Sorting it should give all zeros at front and then 1...n + assert ( + torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) + .view(1, -1) + .expand(batch_size, graph_size) + == sorted_pi[:, -graph_size:] + ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" + + # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) + demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) + d = demand_with_depot.gather(1, actions) + + used_cap = torch.zeros_like(td["demand"][:, 0]) + for i in range(actions.size(1)): + used_cap += d[ + :, i + ] # This will reset/make capacity negative if i == 0, e.g. depot visited + # Cannot use less than 0 + used_cap[used_cap < 0] = 0 + assert ( + used_cap <= td["vehicle_capacity"] + 1e-5 + ).all(), "Used more than capacity" + + @staticmethod + def load_data(fpath, batch_size=[]): + """Dataset loading from file + Normalize demand by capacity to be in [0, 1] + """ + td_load = load_npz_to_tensordict(fpath) + td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) + return td_load + + def _make_spec(self, generator: CVRPGenerator): + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), + dtype=torch.float32, + ), + current_node=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + demand=BoundedTensorSpec( + low=-generator.capacity, + high=generator.max_demand, + shape=(generator.num_loc + 1, 1), + dtype=torch.float32, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(generator.num_loc + 1, 1), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=generator.num_loc + 1, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + @staticmethod + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + return render(td, actions, ax) diff --git a/rl4co/envs/routing/cvrp/generator.py b/rl4co/envs/routing/cvrp/generator.py new file mode 100644 index 00000000..28e01157 --- /dev/null +++ b/rl4co/envs/routing/cvrp/generator.py @@ -0,0 +1,131 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 +CAPACITIES = { + 10: 20.0, + 15: 25.0, + 20: 30.0, + 30: 33.0, + 40: 37.0, + 50: 40.0, + 60: 43.0, + 75: 45.0, + 100: 50.0, + 125: 55.0, + 150: 60.0, + 200: 70.0, + 500: 100.0, + 1000: 150.0, +} + + +class CVRPGenerator(Generator): + """Data generator for the Capacitated Vehicle Routing Problem (CVRP). + Args: + num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + demand_distribution: distribution for the demand of each customer + capacity: capacity of the vehicle + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + demand [batch_size, num_loc]: demand of each customer + capacity [batch_size]: capacity of the vehicle + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_demand: int = 1, + max_demand: int = 10, + demand_distribution: Union[ + int, float, type, Callable + ] = Uniform, + vehicle_capacity: float = 1.0, + capacity: float = None, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_demand = min_demand + self.max_demand = max_demand + self.vehicle_capacity = vehicle_capacity + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + # Demand distribution + if kwargs.get("demand_sampler", None) is not None: + self.demand_sampler = kwargs["demand_sampler"] + else: + self.demand_sampler = get_sampler("demand", demand_distribution, min_demand-1, max_demand-1, **kwargs) + + # Capacity + self.capacity = kwargs.get("capacity", None) + if self.capacity is None: # If not provided, use the default capacity from Kool et al. 2019 + self.capacity = CAPACITIES.get(num_loc, None) + if self.capacity is None: # If not in the table keys, find the closest number of nodes as the key + closest_num_loc = min(CAPACITIES.keys(), key=lambda x: abs(x - num_loc)) + self.capacity = CAPACITIES[closest_num_loc] + log.warning( + f"The capacity capacity for {num_loc} locations is not defined. Using the closest capacity: {self.capacity}\ + with {closest_num_loc} locations." + ) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + # Sample demands + demand = self.demand_sampler.sample((*batch_size, self.num_loc)) + demand = (demand.int() + 1).float() + + # Sample capacities + capacity = torch.full((*batch_size, 1), self.capacity) + + return TensorDict( + { + "locs": locs, + "depot": depot, + "demand": demand / self.capacity, + "capacity": capacity, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/cvrp/render.py b/rl4co/envs/routing/cvrp/render.py new file mode 100644 index 00000000..74748e1c --- /dev/null +++ b/rl4co/envs/routing/cvrp/render.py @@ -0,0 +1,133 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + scale_demand = td["capacity"][0] + demands = td["demand"] * scale_demand + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # plot demand bars + for node_idx in range(1, len(locs)): + ax.add_patch( + plt.Rectangle( + (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), + 0.01, + demands[node_idx - 1] / (scale_demand * 10), + edgecolor=cm.Set2(0), + facecolor=cm.Set2(0), + fill=True, + ) + ) + + # text demand + for node_idx in range(1, len(locs)): + ax.text( + locs[node_idx, 0], + locs[node_idx, 1] - 0.025, + f"{demands[node_idx-1].item():.2f}", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(0), + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/cvrptw.py b/rl4co/envs/routing/cvrptw/env.py similarity index 55% rename from rl4co/envs/routing/cvrptw.py rename to rl4co/envs/routing/cvrptw/env.py index 0d0f43f9..8a8862f4 100644 --- a/rl4co/envs/routing/cvrptw.py +++ b/rl4co/envs/routing/cvrptw/env.py @@ -8,7 +8,8 @@ UnboundedContinuousTensorSpec, ) -from rl4co.envs.routing.cvrp import CVRPEnv, CAPACITIES +from rl4co.envs.routing.cvrp.env import CVRPEnv +from rl4co.envs.routing.cvrp.generator import CAPACITIES from rl4co.utils.ops import gather_by_index, get_distance from rl4co.data.utils import ( load_npz_to_tensordict, @@ -16,161 +17,92 @@ load_solomon_solution, ) +from ..cvrp.generator import CVRPGenerator +from .generator import CVRPTWGenerator +from .render import render + class CVRPTWEnv(CVRPEnv): """Capacitated Vehicle Routing Problem with Time Windows (CVRPTW) environment. - Inherits from the CVRPEnv class in which capacities are considered. + Inherits from the CVRPEnv class in which customers are considered. Additionally considers time windows within which a service has to be started. + Observations: + - location of the depot. + - locations and demand of each customer. + - current location of the vehicle. + - the remaining customer of the vehicle. + - the current time. + - service durations of each location. + - time windows of each location. + + Constraints: + - the tour starts and ends at the depot. + - each customer must be visited exactly once. + - the vehicle cannot visit customers exceed the remaining customer. + - the vehicle can return to the depot to refill the customer. + - the vehicle must start the service within the time window of each location. + + Finish Condition: + - the vehicle has visited all customers and returned to the depot. + + Reward: + - (minus) the negative length of the path. + Args: - num_loc (int): number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) - min_loc (float): minimum value for the location coordinates - max_loc (float): maximum value for the location coordinates. Defaults to 150. - min_demand (float): minimum value for the demand of each customer - max_demand (float): maximum value for the demand of each customer - max_time (int): maximum time for the environment. Defaults to 480. - vehicle_capacity (float): capacity of the vehicle - capacity (float): capacity of the vehicle - scale (bool): if True, the time windows and service durations are scaled to [0, 1]. Defaults to False. - td_params: parameters of the environment + generator: CVRPTWGenerator instance as the data generator + generator_params: parameters for the generator """ name = "cvrptw" def __init__( self, - max_loc: float = 150, # different default value to CVRPEnv to match max_time, will be scaled - max_time: int = 480, - scale: bool = False, + generator: CVRPTWGenerator = None, + generator_params: dict = {}, **kwargs, ): - self.min_time = 0 # always 0 - self.max_time = max_time - self.scale = scale - super().__init__(max_loc=max_loc, **kwargs) - - def _make_spec(self, td_params: TensorDict): - super()._make_spec(td_params) - - current_time = UnboundedContinuousTensorSpec( - shape=(1), dtype=torch.float32, device=self.device - ) - - current_loc = UnboundedContinuousTensorSpec( - shape=(2), dtype=torch.float32, device=self.device - ) - - durations = BoundedTensorSpec( - low=self.min_time, - high=self.max_time, - shape=(self.num_loc, 1), - dtype=torch.int64, - device=self.device, - ) - - time_windows = BoundedTensorSpec( - low=self.min_time, - high=self.max_time, - shape=( - self.num_loc, - 2, - ), # each location has a 2D time window (start, end) - dtype=torch.int64, - device=self.device, - ) - - # extend observation specs - self.observation_spec = CompositeSpec( - **self.observation_spec, - current_time=current_time, - current_loc=current_loc, - durations=durations, - time_windows=time_windows, - # vehicle_idx=vehicle_idx, - ) - - def generate_data(self, batch_size) -> TensorDict: - """ - Generates time windows and service durations for the locations. The depot has a time window of [0, self.max_time]. - The time windows define the time span within which a service has to be started. To reach the depot in time from the last node, - the end time of each node is bounded by the service duration and the distance back to the depot. - The start times of the time windows are bounded by how long it takes to travel there from the depot. - """ - td = super().generate_data(batch_size) - - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + super().__init__(**kwargs) + if generator is None: + generator = CVRPTWGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) - ## define service durations - # generate randomly (first assume service durations of 0, to be changed later) - durations = torch.zeros( - *batch_size, self.num_loc + 1, dtype=torch.float32, device=self.device - ) - - ## define time windows - # 1. get distances from depot - dist = get_distance(td["depot"], td["locs"].transpose(0, 1)).transpose(0, 1) - dist = torch.cat((torch.zeros(*batch_size, 1, device=self.device), dist), dim=1) - # 2. define upper bound for time windows to make sure the vehicle can get back to the depot in time - upper_bound = self.max_time - dist - durations - # 3. create random values between 0 and 1 - ts_1 = torch.rand(*batch_size, self.num_loc + 1, device=self.device) - ts_2 = torch.rand(*batch_size, self.num_loc + 1, device=self.device) - # 4. scale values to lie between their respective min_time and max_time and convert to integer values - min_ts = (dist + (upper_bound - dist) * ts_1).int() - max_ts = (dist + (upper_bound - dist) * ts_2).int() - # 5. set the lower value to min, the higher to max - min_times = torch.min(min_ts, max_ts) - max_times = torch.max(min_ts, max_ts) - # 6. reset times for depot - min_times[..., :, 0] = 0.0 - max_times[..., :, 0] = self.max_time - - # 7. ensure min_times < max_times to prevent numerical errors in attention.py - # min_times == max_times may lead to nan values in _inner_mha() - mask = min_times == max_times - if torch.any(mask): - min_tmp = min_times.clone() - min_tmp[mask] = torch.max( - dist[mask].int(), min_tmp[mask] - 1 - ) # we are handling integer values, so we can simply substract 1 - min_times = min_tmp - - mask = min_times == max_times # update mask to new min_times - if torch.any(mask): - max_tmp = max_times.clone() - max_tmp[mask] = torch.min( - torch.floor(upper_bound[mask]).int(), - torch.max( - torch.ceil(min_tmp[mask] + durations[mask]).int(), - max_tmp[mask] + 1, - ), - ) - max_times = max_tmp - - # scale to [0, 1] - if self.scale: - durations = durations / self.max_time - min_times = min_times / self.max_time - max_times = max_times / self.max_time - td["depot"] = td["depot"] / self.max_time - td["locs"] = td["locs"] / self.max_time - - # 8. stack to tensor time_windows - time_windows = torch.stack((min_times, max_times), dim=-1) - - assert torch.all( - min_times < max_times - ), "Please make sure the relation between max_loc and max_time allows for feasible solutions." - - # reset duration at depot to 0 - durations[:, 0] = 0.0 - td.update( - { - "durations": durations, - "time_windows": time_windows, - } - ) - return td + def _make_spec(self, generator: CVRPTWGenerator): + if isinstance(generator, CVRPGenerator): + super()._make_spec(generator) + else: + current_time = UnboundedContinuousTensorSpec( + shape=(1), dtype=torch.float32, device=self.device + ) + current_loc = UnboundedContinuousTensorSpec( + shape=(2), dtype=torch.float32, device=self.device + ) + durations = BoundedTensorSpec( + low=generator.min_time, + high=generator.max_time, + shape=(generator.num_loc, 1), + dtype=torch.int64, + device=self.device, + ) + time_windows = BoundedTensorSpec( + low=generator.min_time, + high=generator.max_time, + shape=( + generator.num_loc, + 2, + ), # Each location has a 2D time window (start, end) + dtype=torch.int64, + device=self.device, + ) + # Extend observation specs + self.observation_spec = CompositeSpec( + **self.observation_spec, + current_time=current_time, + current_loc=current_loc, + durations=durations, + time_windows=time_windows, + ) @staticmethod def get_action_mask(td: TensorDict) -> torch.Tensor: @@ -211,32 +143,25 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - # Create reset TensorDict + device = td.device td_reset = TensorDict( { "locs": torch.cat((td["depot"][..., None, :], td["locs"]), -2), "demand": td["demand"], "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device + *batch_size, 1, dtype=torch.long, device=device ), "current_time": torch.zeros( - *batch_size, 1, dtype=torch.float32, device=self.device + *batch_size, 1, dtype=torch.float32, device=device ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "used_capacity": torch.zeros((*batch_size, 1), device=device), "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device + (*batch_size, 1), self.generator.vehicle_capacity, device=device ), "visited": torch.zeros( (*batch_size, 1, td["locs"].shape[-2] + 1), dtype=torch.uint8, - device=self.device, + device=device, ), "durations": td["durations"], "time_windows": td["time_windows"], @@ -246,10 +171,10 @@ def _reset( td_reset.set("action_mask", self.get_action_mask(td_reset)) return td_reset - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: """The reward is the negative tour length. Time windows are not considered for the calculation of the reward.""" - return super().get_reward(td, actions) + return super()._get_reward(td, actions) @staticmethod def check_solution_validity(td: TensorDict, actions: torch.Tensor): @@ -300,8 +225,8 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): curr_time[curr_node == 0] = 0.0 # reset time for depot @staticmethod - def render(td: TensorDict, actions=None, ax=None, scale_xy: bool = False, **kwargs): - CVRPEnv.render(td=td, actions=actions, ax=ax, scale_xy=scale_xy, **kwargs) + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + render(td, actions, ax) @staticmethod def load_data( diff --git a/rl4co/envs/routing/cvrptw/generator.py b/rl4co/envs/routing/cvrptw/generator.py new file mode 100644 index 00000000..90d8218b --- /dev/null +++ b/rl4co/envs/routing/cvrptw/generator.py @@ -0,0 +1,161 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.envs.routing.cvrp.generator import CVRPGenerator +from rl4co.utils.ops import get_distance + + +class CVRPTWGenerator(CVRPGenerator): + """Data generator for the Capacitated Vehicle Routing Problem with Time Windows (CVRPTW) environment + Generates time windows and service durations for the locations. The depot has a time window of [0, self.max_time]. + The time windows define the time span within which a service has to be started. To reach the depot in time from the last node, + the end time of each node is bounded by the service duration and the distance back to the depot. + The start times of the time windows are bounded by how long it takes to travel there from the depot. + + Args: + num_loc: number of locations (customers) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates, default is 150 insted of 1.0, will be scaled + loc_distribution: distribution for the location coordinates + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + demand_distribution: distribution for the demand of each customer + capacity: capacity of the vehicle + max_time: maximum time for the vehicle to complete the tour + scale: if True, the locations, time windows, and service durations will be scaled to [0, 1]. Default to False + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each city + depot [batch_size, 2]: location of the depot + demand [batch_size, num_loc]: demand of each customer + while the demand of the depot is a placeholder + capacity [batch_size, 1]: capacity of the vehicle + durations [batch_size, num_loc]: service durations of each location + time_windows [batch_size, num_loc, 2]: time windows of each location + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 150.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_demand: int = 1, + max_demand: int = 10, + demand_distribution: Union[ + int, float, type, Callable + ] = Uniform, + vehicle_capacity: float = 1.0, + capacity: float = None, + max_time: float = 480, + scale: bool = False, + **kwargs, + ): + super().__init__( + num_loc=num_loc, + min_loc=min_loc, + max_loc=max_loc, + loc_distribution=loc_distribution, + depot_distribution=depot_distribution, + min_demand=min_demand, + max_demand=max_demand, + demand_distribution=demand_distribution, + vehicle_capacity=vehicle_capacity, + capacity=capacity, + **kwargs, + ) + self.max_loc = max_loc + self.min_time = 0.0 + self.max_time = max_time + self.scale = scale + + def _generate(self, batch_size) -> TensorDict: + td = super()._generate(batch_size) + + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + ## define service durations + # generate randomly (first assume service durations of 0, to be changed later) + durations = torch.zeros( + *batch_size, self.num_loc + 1, dtype=torch.float32 + ) + + ## define time windows + # 1. get distances from depot + dist = get_distance(td["depot"], td["locs"].transpose(0, 1)).transpose(0, 1) + dist = torch.cat((torch.zeros(*batch_size, 1), dist), dim=1) + + # 2. define upper bound for time windows to make sure the vehicle can get back to the depot in time + upper_bound = self.max_time - dist - durations + + # 3. create random values between 0 and 1 + ts_1 = torch.rand(*batch_size, self.num_loc + 1) + ts_2 = torch.rand(*batch_size, self.num_loc + 1) + + # 4. scale values to lie between their respective min_time and max_time and convert to integer values + min_ts = (dist + (upper_bound - dist) * ts_1).int() + max_ts = (dist + (upper_bound - dist) * ts_2).int() + + # 5. set the lower value to min, the higher to max + min_times = torch.min(min_ts, max_ts) + max_times = torch.max(min_ts, max_ts) + + # 6. reset times for depot + min_times[..., :, 0] = 0.0 + max_times[..., :, 0] = self.max_time + + # 7. ensure min_times < max_times to prevent numerical errors in attention.py + # min_times == max_times may lead to nan values in _inner_mha() + mask = min_times == max_times + if torch.any(mask): + min_tmp = min_times.clone() + min_tmp[mask] = torch.max( + dist[mask].int(), min_tmp[mask] - 1 + ) # we are handling integer values, so we can simply substract 1 + min_times = min_tmp + + mask = min_times == max_times # update mask to new min_times + if torch.any(mask): + max_tmp = max_times.clone() + max_tmp[mask] = torch.min( + torch.floor(upper_bound[mask]).int(), + torch.max( + torch.ceil(min_tmp[mask] + durations[mask]).int(), + max_tmp[mask] + 1, + ), + ) + max_times = max_tmp + + # Scale to [0, 1] + if self.scale: + durations = durations / self.max_time + min_times = min_times / self.max_time + max_times = max_times / self.max_time + td["depot"] = td["depot"] / self.max_time + td["locs"] = td["locs"] / self.max_time + + # 8. stack to tensor time_windows + time_windows = torch.stack((min_times, max_times), dim=-1) + + assert torch.all( + min_times < max_times + ), "Please make sure the relation between max_loc and max_time allows for feasible solutions." + + # Reset duration at depot to 0 + durations[:, 0] = 0.0 + td.update( + { + "durations": durations, + "time_windows": time_windows, + } + ) + return td diff --git a/rl4co/envs/routing/cvrptw/render.py b/rl4co/envs/routing/cvrptw/render.py new file mode 100644 index 00000000..74748e1c --- /dev/null +++ b/rl4co/envs/routing/cvrptw/render.py @@ -0,0 +1,133 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + scale_demand = td["capacity"][0] + demands = td["demand"] * scale_demand + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # plot demand bars + for node_idx in range(1, len(locs)): + ax.add_patch( + plt.Rectangle( + (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), + 0.01, + demands[node_idx - 1] / (scale_demand * 10), + edgecolor=cm.Set2(0), + facecolor=cm.Set2(0), + fill=True, + ) + ) + + # text demand + for node_idx in range(1, len(locs)): + ax.text( + locs[node_idx, 0], + locs[node_idx, 1] - 0.025, + f"{demands[node_idx-1].item():.2f}", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(0), + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/mdcpdp.py b/rl4co/envs/routing/mdcpdp/env.py similarity index 58% rename from rl4co/envs/routing/mdcpdp.py rename to rl4co/envs/routing/mdcpdp/env.py index 4f0220b5..1f2f5dec 100644 --- a/rl4co/envs/routing/mdcpdp.py +++ b/rl4co/envs/routing/mdcpdp/env.py @@ -13,6 +13,9 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.utils.ops import gather_by_index, get_tour_length +from .generator import MDCPDPGenerator +from .render import render + class MDCPDPEnv(RL4COEnvBase): """Multi Depot Capacitated Pickup and Delivery Problem (MDCPDP) environment. @@ -25,66 +28,65 @@ class MDCPDPEnv(RL4COEnvBase): The goal is to visit all the pickup and delivery locations in the shortest path possible starting from the depot The conditions is that the agent must visit a pickup location before visiting its corresponding delivery location The capacity is the maximum number of pickups that the vehicle can carry at the same time + + Observations: + - locs: locations of the cities [num_loc + num_depot, 2] + - current_node: current node of the agent [1] + - to_deliver: if the node is to deliver [1] + - i: current step [1] + - action_mask: mask of the available actions [num_loc + num_depot] + - shape: shape of the observation + + Constraints: + - The agent cannot visit the same city twice + - The agent must visit the pickup location before the delivery location + - The agent must visit the depot at the end of the tour + + Finish Condition: + - The agent visited all the locations + + Reward: + - Min-sum: the reward is the negative of the length of the tour + - Min-max: the reward is the negative of the maximum length of the tour + - Lateness: the reward is the negative of the cumulate sum of the length of the tour + - Lateness-square: the reward is the negative of the cumulate sum of the square of the length of the tour + Args: - num_loc: number of locations (cities) in the TSP - num_depot: number of depots, each depot has one vehicle - min_loc: minimum value of the location - max_loc: maximum value of the location - min_capacity: minimum value of the capacity - max_capacity: maximum value of the capacity - min_lateness_weight: minimum value of the lateness weight - max_lateness_weight: maximum value of the lateness weight + generator: MDCPDPGenerator instance as the data generator + generator_params: parameters for the generator dist_mode: distance mode. One of ["L1", "L2"] reward_mode: objective of the problem. One of ["lateness", "lateness_square", "minmax", "minsum"] problem_mode: type of the problem. One of ["close", "open"] start_mode: type of the start. One of ["order", "random"] - depot_mode: type of the depot. One of ["single", "multiple"], are all depots the same place - td_params: parameters of the environment - seed: seed for the environment - device: device to use. Generally, no need to set as tensors are updated on the fly """ name = "mdcpdp" def __init__( self, - num_loc: int = 20, - num_depot: int = 5, - min_loc: float = 0, - max_loc: float = 1, - min_capacity: int = 1, - max_capacity: int = 5, - min_lateness_weight: float = 1.0, - max_lateness_weight: float = 1.0, + generator: MDCPDPGenerator = None, + generator_params: dict = {}, dist_mode: str = "L2", reward_mode: str = "lateness", problem_mode: str = "close", start_mode: str = "order", - depot_mode: str = "multiple", - td_params: TensorDict = None, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.num_depot = num_depot - self.min_loc = min_loc - self.max_loc = max_loc - self.min_capacity = min_capacity - self.max_capacity = max_capacity - self.min_lateness_weight = min_lateness_weight - self.max_lateness_weight = max_lateness_weight + if generator is None: + generator = MDCPDPGenerator(**generator_params) + self.generator = generator self.dist_mode = dist_mode self.reward_mode = reward_mode self.problem_mode = problem_mode self.start_mode = start_mode - self.depot_mode = depot_mode - self._make_spec(td_params) + self.depot_mode = generator.depot_mode + self._make_spec(self.generator) assert self.dist_mode in ["L1", "L2"], "Distance mode (L1/L2) not supported" assert self.reward_mode in ["lateness", "lateness_square", "minmax", "minsum"], "Objective mode not supported" assert self.problem_mode in ["close", "open"], "Task type (open/close) not supported" assert self.start_mode in ["order", "random"], "Start type (order/random) not supported" - assert self.depot_mode in ["single", "multiple"], "Depot type (single/multiple) not supported" def _step(self, td: TensorDict) -> TensorDict: current_node = td["action"].unsqueeze(-1) @@ -192,30 +194,23 @@ def _step(self, td: TensorDict) -> TensorDict: return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td.batch_size - - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - - self.to(td.device) - + device = td.device locs = torch.cat((td["depot"], td["locs"]), -2) # Record how many depots are visited - depot_idx = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) + depot_idx = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) # Pick is 1, deliver is 0 [batch_size, graph_size+1], i.e. [1, 1, ..., 1, 0, ...0] to_deliver = torch.cat( [ torch.ones( *batch_size, - self.num_loc // 2 + self.num_depot, + self.generator.num_loc // 2 + self.generator.num_depot, dtype=torch.bool, - device=self.device, + device=device, ), torch.zeros( - *batch_size, self.num_loc // 2, dtype=torch.bool, device=self.device + *batch_size, self.generator.num_loc // 2, dtype=torch.bool, device=device ), ], dim=-1, @@ -224,32 +219,32 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict # Current depot index if self.start_mode == "random": current_depot = torch.randint( - low=0, high=self.num_depot, size=(*batch_size, 1), device=self.device + low=0, high=self.generator.num_depot, size=(*batch_size, 1), device=device ) elif self.start_mode == "order": - current_depot = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) + current_depot = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) # Current carry order number - current_carry = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) + current_carry = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) # Current length of each depot - current_length = torch.zeros((*batch_size, self.num_depot), dtype=torch.float32, device=self.device) + current_length = torch.zeros((*batch_size, self.generator.num_depot), dtype=torch.float32, device=device) # Arrive time for each city - arrivetime_record = torch.zeros((*batch_size, self.num_loc + self.num_depot), dtype=torch.float32, device=self.device) + arrivetime_record = torch.zeros((*batch_size, self.generator.num_loc + self.generator.num_depot), dtype=torch.float32, device=device) # Cannot visit depot at first step # [0,1...1] so set not available available = torch.ones( - (*batch_size, self.num_loc + self.num_depot), dtype=torch.bool, device=self.device + (*batch_size, self.generator.num_loc + self.generator.num_depot), dtype=torch.bool, device=device ) action_mask = ~available.contiguous() # [batch_size, graph_size+1] action_mask[..., 0] = 1 # First step is always the depot # Other variables current_node = torch.zeros( - (*batch_size, 1), dtype=torch.int64, device=self.device + (*batch_size, 1), dtype=torch.int64, device=device ) - i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) + i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) return TensorDict( { @@ -270,13 +265,13 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: MDCPDPGenerator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -292,7 +287,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1), + shape=(generator.num_loc + 1), dtype=torch.bool, ), shape=(), @@ -301,7 +296,7 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc + 1, + high=generator.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) @@ -315,7 +310,7 @@ def get_distance(self, prev_loc, cur_loc): else: raise ValueError(f"Invalid distance norm: {self.dist_norm}") - def get_reward(self, td: TensorDict, actions) -> TensorDict: + def _get_reward(self, td: TensorDict, actions) -> TensorDict: """Return the rewrad for the current state Support modes: - minmax: the reward is the maximum length of all agents @@ -350,178 +345,10 @@ def get_reward(self, td: TensorDict, actions) -> TensorDict: raise NotImplementedError(f"Invalid reward mode: {self.reward_mode}. Available modes: minmax, minsum, lateness_square, lateness") return -cost # minus for reward - def generate_data(self, batch_size) -> TensorDict: - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - num_orders = self.num_loc // 2 - - # Pickup locations - pickup_locs = ( - torch.FloatTensor(*batch_size, num_orders, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Delivery locations - delivery_locs = ( - torch.FloatTensor(*batch_size, num_orders, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Depots locations - if self.depot_mode == "single": - depot_locs = ( - torch.FloatTensor(*batch_size, 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ).repeat(1, self.num_depot, 1) - elif self.depot_mode == "multiple": - depot_locs = ( - torch.FloatTensor(*batch_size, self.num_depot, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Capacity - capacity = torch.randint( - low=self.min_capacity, - high=self.max_capacity + 1, - size=(*batch_size, self.num_depot), - ) - - # Lateness weight - lateness_weight = ( - torch.FloatTensor(*batch_size, 1) - .uniform_(self.min_lateness_weight, self.max_lateness_weight) - .to(self.device) - ) - - - return TensorDict( - { - "locs": torch.cat([pickup_locs, delivery_locs], dim=-2), # No depot - "depot": depot_locs, - "capacity": capacity, - "lateness_weight": lateness_weight, - }, - batch_size=batch_size, - ) - @staticmethod - def render(td: TensorDict, actions=None, ax=None): - import matplotlib.pyplot as plt - markersize = 8 - - td = td.detach().cpu() - - # If batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - if actions is not None: - actions = actions[0] - - n_depots = td["capacity"].size(-1) - n_pickups = (td["locs"].size(-2) - n_depots) // 2 - - # Variables - init_deliveries = td["to_deliver"][n_depots:] - delivery_locs = td["locs"][n_depots:][~init_deliveries.bool()] - pickup_locs = td["locs"][n_depots:][init_deliveries.bool()] - depot_locs = td["locs"][:n_depots] - actions = actions if actions is not None else td["action"] - - if ax is None: - _, ax = plt.subplots(figsize=(4, 4)) - - # Plot the actions in order - last_depot = 0 - for i in range(len(actions)-1): - if actions[i+1] < n_depots: - last_depot = actions[i+1] - if actions[i] < n_depots and actions[i+1] < n_depots: - continue - from_node = actions[i] - to_node = ( - actions[i + 1] if i < len(actions) - 1 else actions[0] - ) # last goes back to depot - from_loc = td["locs"][from_node] - to_loc = td["locs"][to_node] - ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="->", color="black"), - annotation_clip=False, - ) - - # Plot last back to the depot - from_node = actions[-1] - to_node = last_depot - from_loc = td["locs"][from_node] - to_loc = td["locs"][to_node] - ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="->", color="black"), - annotation_clip=False, - ) - - # Annotate node location - for i, loc in enumerate(td["locs"]): - ax.annotate( - str(i), - (loc[0], loc[1]), - textcoords="offset points", - xytext=(0, 5), - ha="center", - ) - - for i, depot_loc in enumerate(depot_locs): - ax.plot( - depot_loc[0], - depot_loc[1], - "tab:green", - marker="s", - markersize=markersize, - label="Depot" if i == 0 else None, - ) - - # Plot the pickup locations - for i, pickup_loc in enumerate(pickup_locs): - ax.plot( - pickup_loc[0], - pickup_loc[1], - "tab:red", - marker="^", - markersize=markersize, - label="Pickup" if i == 0 else None, - ) + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + assert True, "Not implemented" - # Plot the delivery locations - for i, delivery_loc in enumerate(delivery_locs): - ax.plot( - delivery_loc[0], - delivery_loc[1], - "tab:blue", - marker="x", - markersize=markersize, - label="Delivery" if i == 0 else None, - ) - - # Plot pickup and delivery pair: from loc[n_depot + i ] to loc[n_depot + n_pickups + i] - for i in range(n_pickups): - pickup_loc = td["locs"][n_depots + i] - delivery_loc = td["locs"][n_depots + n_pickups + i] - ax.plot( - [pickup_loc[0], delivery_loc[0]], - [pickup_loc[1], delivery_loc[1]], - "k--", - alpha=0.5, - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) + @staticmethod + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + return render(td, actions, ax) diff --git a/rl4co/envs/routing/mdcpdp/generator.py b/rl4co/envs/routing/mdcpdp/generator.py new file mode 100644 index 00000000..284e46a0 --- /dev/null +++ b/rl4co/envs/routing/mdcpdp/generator.py @@ -0,0 +1,126 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class MDCPDPGenerator(Generator): + """Data generator for the Multi Depot Capacitated Pickup and Delivery Problem (MDCPDP) environment. + + Args: + num_loc: number of locations (customers) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates, default is 150 insted of 1.0, will be scaled + loc_distribution: distribution for the location coordinates + num_depot: number of depots, each depot has one vehicle + depot_mode: mode for the depot, either single or multiple + depod_distribution: distribution for the depot coordinates + min_capacity: minimum value of the capacity + max_capacity: maximum value of the capacity + min_lateness_weight: minimum value of the lateness weight + max_lateness_weight: maximum value of the lateness weight + latebess_weight_distribution: distribution for the lateness weight + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, num_depot, 2]: locations of each depot + capacity [batch_size, 1]: capacity of the vehicle + lateness_weight [batch_size, 1]: weight of the lateness cost + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + num_depot: int = 5, + depot_mode: str = "multiple", + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_capacity: int = 1, + max_capacity: int = 5, + min_lateness_weight: float = 1.0, + max_lateness_weight: float = 1.0, + lateness_weight_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.depot_mode = depot_mode + self.num_depot = num_depot + self.min_capacity = min_capacity + self.max_capacity = max_capacity + self.min_lateness_weight = min_lateness_weight + self.max_lateness_weight = max_lateness_weight + + # Number of locations must be even + if num_loc % 2 != 0: + log.warn("Number of locations must be even. Adding 1 to the number of locations.") + self.num_loc += 1 + + # Check depot mode validity + assert depot_mode in ["single", "multiple"], f"Invalid depot mode: {depot_mode}" + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + # Lateness weight distribution + if kwargs.get("lateness_weight_sampler", None) is not None: + self.lateness_weight_sampler = kwargs["lateness_weight_sampler"] + else: + self.lateness_weight_sampler = get_sampler( + "lateness_weight", lateness_weight_distribution, min_lateness_weight, max_lateness_weight, **kwargs + ) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + if self.depot_mode == "single": + depot = self.depot_sampler.sample((*batch_size, 2))[:, None, :].repeat(1, self.num_depot, 1) + else: + depot = self.depot_sampler.sample((*batch_size, self.num_depot, 2)) + + # Sample capacity + capacity = torch.randint( + self.min_capacity, + self.max_capacity + 1, + size=(*batch_size, 1), + ) + + # Sample lateness weight + lateness_weight = self.lateness_weight_sampler.sample((*batch_size, 1)) + + return TensorDict( + { + "locs": locs, + "depot": depot, + "capacity": capacity, + "lateness_weight": lateness_weight, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/mdcpdp/render.py b/rl4co/envs/routing/mdcpdp/render.py new file mode 100644 index 00000000..5711b1d7 --- /dev/null +++ b/rl4co/envs/routing/mdcpdp/render.py @@ -0,0 +1,120 @@ +from tensordict.tensordict import TensorDict + + +def render(td: TensorDict, actions=None, ax=None): + import matplotlib.pyplot as plt + markersize = 8 + + td = td.detach().cpu() + + # If batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + if actions is not None: + actions = actions[0] + + n_depots = td["capacity"].size(-1) + n_pickups = (td["locs"].size(-2) - n_depots) // 2 + + # Variables + init_deliveries = td["to_deliver"][n_depots:] + delivery_locs = td["locs"][n_depots:][~init_deliveries.bool()] + pickup_locs = td["locs"][n_depots:][init_deliveries.bool()] + depot_locs = td["locs"][:n_depots] + actions = actions if actions is not None else td["action"] + + if ax is None: + _, ax = plt.subplots(figsize=(4, 4)) + + # Plot the actions in order + last_depot = 0 + for i in range(len(actions)-1): + if actions[i+1] < n_depots: + last_depot = actions[i+1] + if actions[i] < n_depots and actions[i+1] < n_depots: + continue + from_node = actions[i] + to_node = ( + actions[i + 1] if i < len(actions) - 1 else actions[0] + ) # last goes back to depot + from_loc = td["locs"][from_node] + to_loc = td["locs"][to_node] + ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="->", color="black"), + annotation_clip=False, + ) + + # Plot last back to the depot + from_node = actions[-1] + to_node = last_depot + from_loc = td["locs"][from_node] + to_loc = td["locs"][to_node] + ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="->", color="black"), + annotation_clip=False, + ) + + # Annotate node location + for i, loc in enumerate(td["locs"]): + ax.annotate( + str(i), + (loc[0], loc[1]), + textcoords="offset points", + xytext=(0, 5), + ha="center", + ) + + for i, depot_loc in enumerate(depot_locs): + ax.plot( + depot_loc[0], + depot_loc[1], + "tab:green", + marker="s", + markersize=markersize, + label="Depot" if i == 0 else None, + ) + + # Plot the pickup locations + for i, pickup_loc in enumerate(pickup_locs): + ax.plot( + pickup_loc[0], + pickup_loc[1], + "tab:red", + marker="^", + markersize=markersize, + label="Pickup" if i == 0 else None, + ) + + # Plot the delivery locations + for i, delivery_loc in enumerate(delivery_locs): + ax.plot( + delivery_loc[0], + delivery_loc[1], + "tab:blue", + marker="x", + markersize=markersize, + label="Delivery" if i == 0 else None, + ) + + # Plot pickup and delivery pair: from loc[n_depot + i ] to loc[n_depot + n_pickups + i] + for i in range(n_pickups): + pickup_loc = td["locs"][n_depots + i] + delivery_loc = td["locs"][n_depots + n_pickups + i] + ax.plot( + [pickup_loc[0], delivery_loc[0]], + [pickup_loc[1], delivery_loc[1]], + "k--", + alpha=0.5, + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/mpdp.py b/rl4co/envs/routing/mpdp/env.py similarity index 69% rename from rl4co/envs/routing/mpdp.py rename to rl4co/envs/routing/mpdp/env.py index 56e9e93b..37e19b70 100644 --- a/rl4co/envs/routing/mpdp.py +++ b/rl4co/envs/routing/mpdp/env.py @@ -14,51 +14,57 @@ from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger +from .generator import MPDPGenerator +from .render import render + log = get_pylogger(__name__) class MPDPEnv(RL4COEnvBase): - """Multi-agent Pickup and Delivery Problem environment. + """Multi-agent Pickup and Delivery Problem (mPDP) environment. The goal is to pick up and deliver all the packages while satisfying the precedence constraints. When an agent goes back to the depot, a new agent is spawned. In the min-max version, the goal is to minimize the - maximum tour length among all agents. - The reward is 0 unless the agent visits all the cities. + maximum tour length among all agents. The reward is 0 unless the agent visits all the customers. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + Observations: + - locations of the depot, pickup, and delivery locations + - current location of the vehicle + - the remaining locations to deliver + - the visited locations + - the current step + + Constraints: + - the tour starts and ends at the depot + - each pickup location must be visited before its corresponding delivery location + - the vehicle cannot visit the same location twice + + Finish Condition: + - the vehicle has visited all locations + + Reward: + - (minus) the negative length of the path + Args: - num_loc: number of locations (cities) in the TSP - min_loc: minimum location coordinate. Used for data generation - max_loc: maximum location coordinate. Used for data generation - min_num_agents: minimum number of agents. Used for data generation - max_num_agents: maximum number of agents. Used for data generation - objective: objective to optimize. Either 'minmax' or 'minsum' - check_solution: whether to check the validity of the solution - td_params: parameters of the environment + generator: MPDPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "mpdp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_num_agents: int = 2, - max_num_agents: int = 10, + generator: MPDPGenerator = None, + generator_params: dict = {}, objective: str = "minmax", - check_solution: bool = False, - td_params: TensorDict = None, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_num_agents = min_num_agents - self.max_num_agents = max_num_agents + if generator is None: + generator = MPDPGenerator(**generator_params) + self.generator = generator self.objective = objective - self.check_solution = check_solution - self._make_spec(td_params) + self._make_spec(self.generator) def _step(self, td: TensorDict) -> TensorDict: selected = td["action"][:, None] # Add dimension for step @@ -137,13 +143,7 @@ def _reset( batch_size: Optional[list] = None, agent_num: Optional[int] = None, # NOTE hardcoded from ET ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - - self.to(td.device) + device = td.device # NOTE: this is a hack to get the agent_num # agent_num = td["agent_num"][0].item() if agent_num is None else agent_num @@ -157,7 +157,7 @@ def _reset( # Distance from all nodes between each other distance = torch.cdist(whole_instance, whole_instance, p=2) - index = torch.arange(left_request, 2 * left_request, device=depot.device)[ + index = torch.arange(left_request, 2 * left_request, device=device)[ None, :, None ] index = index.repeat(distance.shape[0], 1, 1) @@ -195,10 +195,10 @@ def _reset( 1, n_loc // 2 + agent_num + 1, dtype=torch.uint8, - device=loc.device, + device=device, ), torch.zeros( - batch_size, 1, n_loc // 2, dtype=torch.uint8, device=loc.device + batch_size, 1, n_loc // 2, dtype=torch.uint8, device=device ), ], dim=-1, @@ -213,25 +213,25 @@ def _reset( 1, n_loc + agent_num + 1, dtype=torch.uint8, - device=loc.device, + device=device, ), - "lengths": torch.zeros(batch_size, agent_num, device=loc.device), - "longest_lengths": torch.zeros(batch_size, agent_num, device=loc.device), + "lengths": torch.zeros(batch_size, agent_num, device=device), + "longest_lengths": torch.zeros(batch_size, agent_num, device=device), "cur_coord": td["depot"] if len(td["depot"].shape) == 2 else td["depot"].squeeze(1), "i": torch.zeros( - batch_size, dtype=torch.int64, device=loc.device + batch_size, dtype=torch.int64, device=device ), # Vector with length num_steps "to_delivery": to_delivery, "count_depot": torch.zeros( - batch_size, 1, dtype=torch.int64, device=loc.device + batch_size, 1, dtype=torch.int64, device=device ), "agent_idx": torch.ones( - batch_size, 1, dtype=torch.long, device=loc.device + batch_size, 1, dtype=torch.long, device=device ), "left_request": left_request - * torch.ones(batch_size, 1, dtype=torch.long, device=loc.device), + * torch.ones(batch_size, 1, dtype=torch.long, device=device), "remain_pickup_max_distance": remain_pickup_max_distance, "remain_delivery_max_distance": remain_delivery_max_distance, "depot_distance": depot_distance, @@ -296,11 +296,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: action_mask = mask_loc == 0 # action_mask gets feasible actions return action_mask - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - # Check that the solution is valid - if self.check_solution: - self.check_solution_validity(td, actions) - + def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: # Calculate the reward (negative tour length) if self.objective == "minmax": return -td["lengths"].max(dim=-1, keepdim=True)[0].squeeze(-1) @@ -313,32 +309,13 @@ def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: def check_solution_validity(td: TensorDict, actions: torch.Tensor): assert True, "Not implemented" - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - }, - batch_size=batch_size, - ) - - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: MPDPGenerator): """Make the observation and action specs from the parameters.""" max_nodes = self.num_loc + self.max_num_agents + 1 self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, + low=generator.min_loc, + high=generator.max_loc, shape=(max_nodes, 2), dtype=torch.float32, ), @@ -355,16 +332,16 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.bool, ), lengths=UnboundedContinuousTensorSpec( - shape=(self.max_num_agents,), + shape=(generator.max_num_agents,), dtype=torch.float32, ), longest_lengths=UnboundedContinuousTensorSpec( - shape=(self.max_num_agents,), + shape=(generator.max_num_agents,), dtype=torch.float32, ), cur_coord=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, + low=generator.min_loc, + high=generator.max_loc, shape=(2,), dtype=torch.float32, ), @@ -424,105 +401,5 @@ def _make_spec(self, td_params: TensorDict): self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) @staticmethod - def render(td: TensorDict, actions=None, ax=None): - # TODO: color switch with new agents; add pickup and delivery nodes as in `PDPEnv.render` - - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + return render(td, actions, ax) \ No newline at end of file diff --git a/rl4co/envs/routing/mpdp/generator.py b/rl4co/envs/routing/mpdp/generator.py new file mode 100644 index 00000000..35120305 --- /dev/null +++ b/rl4co/envs/routing/mpdp/generator.py @@ -0,0 +1,90 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class MPDPGenerator(Generator): + """Data generator for the Capacitated Vehicle Routing Problem (CVRP). + Args: + num_loc: number of locations + min_loc: minimum location value + max_loc: maximum location value + loc_distribution: distribution for the locations + depot_distribution: distribution for the depot + min_num_agents: minimum number of agents + max_num_agents: maximum number of agents + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer and the depot + depot [batch_size, 2]: location of the depot + num_agents [batch_size]: number of agents + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_num_agents: int = 2, + max_num_agents: int = 10, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_num_agents = min_num_agents + self.max_num_agents = max_num_agents + + # Number of locations must be even + if num_loc % 2 != 0: + log.warn("Number of locations must be even. Adding 1 to the number of locations.") + self.num_loc += 1 + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + # Sample the number of agents + num_agents = torch.randint( + self.min_num_agents, + self.max_num_agents + 1, + size=(*batch_size, ), + ) + + return TensorDict( + { + "locs": locs, + "depot": depot, + "num_agents": num_agents, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/mpdp/render.py b/rl4co/envs/routing/mpdp/render.py new file mode 100644 index 00000000..1f49a2f9 --- /dev/null +++ b/rl4co/envs/routing/mpdp/render.py @@ -0,0 +1,114 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + # TODO: color switch with new agents; add pickup and delivery nodes as in `PDPEnv.render` + + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import cm, colormaps + + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/mtsp.py b/rl4co/envs/routing/mtsp/env.py similarity index 58% rename from rl4co/envs/routing/mtsp.py rename to rl4co/envs/routing/mtsp/env.py index 203da7ed..2259be41 100644 --- a/rl4co/envs/routing/mtsp.py +++ b/rl4co/envs/routing/mtsp/env.py @@ -15,6 +15,9 @@ from rl4co.envs.common.utils import batch_to_scalar from rl4co.utils.ops import gather_by_index, get_distance, get_tour_length +from .generator import MTSPGenerator +from .render import render + class MTSPEnv(RL4COEnvBase): """Multiple Traveling Salesman Problem environment @@ -24,37 +27,45 @@ class MTSPEnv(RL4COEnvBase): - `sum`: the cost is the sum of the path lengths of all the agents Reward is - cost, so the goal is to maximize the reward (minimize the cost). + Observations: + - locations of the depot and each customer. + - number of agents. + - the current agent index. + - the current location of the vehicle. + + Constrains: + - each agent's tour starts and ends at the depot. + - each customer must be visited exactly once. + + Finish condition: + - all customers are visited and all agents back to the depot. + + Reward: + There are two ways to calculate the cost (-reward): + - `minmax`: (default) the cost is the maximum of the path lengths of all the agents. + - `sum`: the cost is the sum of the path lengths of all the agents. + Args: - num_loc: number of locations (cities) to visit - min_loc: minimum value of the locations - max_loc: maximum value of the locations - min_num_agents: minimum number of agents - max_num_agents: maximum number of agents cost_type: type of cost to use, either `minmax` or `sum` - td_params: parameters for the TensorDict specs + generator: MTSPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "mtsp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_num_agents: int = 5, - max_num_agents: int = 5, + generator: MTSPGenerator = None, + generator_params: dict = {}, cost_type: str = "minmax", - td_params: TensorDict = None, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_num_agents = min_num_agents - self.max_num_agents = max_num_agents + if generator is None: + generator = MTSPGenerator(**generator_params) + self.generator = generator self.cost_type = cost_type - self._make_spec(td_params) + self._make_spec(self.generator) @staticmethod def _step(td: TensorDict) -> TensorDict: @@ -128,13 +139,7 @@ def _step(td: TensorDict) -> TensorDict: return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - # Initialize data - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - - device = td.device if td is not None else self.device - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) + device = td.device # Keep track of the agent number to know when to stop agent_idx = torch.zeros((*batch_size,), dtype=torch.int64, device=device) @@ -146,7 +151,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict # Other variables current_node = torch.zeros((*batch_size,), dtype=torch.int64, device=device) available = torch.ones( - (*batch_size, self.num_loc), dtype=torch.bool, device=device + (*batch_size, self.generator.num_loc), dtype=torch.bool, device=device ) # 1 means not visited, i.e. action is allowed available[..., 0] = 0 # Depot is not available as first node i = torch.zeros((*batch_size,), dtype=torch.int64, device=device) @@ -166,13 +171,13 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: MTSPGenerator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc, 2), dtype=torch.float32, ), num_agents=UnboundedDiscreteTensorSpec( @@ -204,7 +209,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc), + shape=(generator.num_loc), dtype=torch.bool, ), shape=(), @@ -213,12 +218,12 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc, + high=generator.num_loc, ) self.reward_spec = UnboundedContinuousTensorSpec() self.done_spec = UnboundedDiscreteTensorSpec(dtype=torch.bool) - def get_reward(self, td, actions=None) -> TensorDict: + def _get_reward(self, td, actions=None) -> TensorDict: # With minmax, get the maximum distance among subtours, calculated in the model if self.cost_type == "minmax": return td["reward"].squeeze(-1) @@ -232,124 +237,10 @@ def get_reward(self, td, actions=None) -> TensorDict: else: raise ValueError(f"Cost type {self.cost_type} not supported") - def generate_data(self, batch_size) -> TensorDict: - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs = ( - torch.FloatTensor(*batch_size, self.num_loc, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize the num_agents: either fixed or random integer between min and max - if self.min_num_agents == self.max_num_agents: - num_agents = ( - torch.ones(batch_size, dtype=torch.int64, device=self.device) - * self.min_num_agents - ) - else: - num_agents = torch.randint( - self.min_num_agents, - self.max_num_agents, - size=batch_size, - device=self.device, - ) - - return TensorDict( - { - "locs": locs, - "num_agents": num_agents, - }, - batch_size=batch_size, - ) + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + assert True, "Not implemented" @staticmethod def render(td, actions=None, ax=None): - import matplotlib.pyplot as plt - - from matplotlib import colormaps - - def discrete_cmap(num, base_cmap="nipy_spectral"): - """Create an N-bin discrete colormap from the specified input map""" - base = colormaps[base_cmap] - color_list = base(np.linspace(0, 1, num)) - cmap_name = base.name + str(num) - return base.from_list(cmap_name, color_list, num) - - if actions is None: - actions = td.get("action", None) - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - num_agents = td["num_agents"] - locs = td["locs"] - cmap = discrete_cmap(num_agents, "rainbow") - - fig, ax = plt.subplots() - - # Add depot action = 0 to before first action and after last action - actions = torch.cat( - [ - torch.zeros(1, dtype=torch.int64), - actions, - torch.zeros(1, dtype=torch.int64), - ] - ) - - # Make list of colors from matplotlib - for i, loc in enumerate(locs): - if i == 0: - # depot - marker = "s" - color = "g" - label = "Depot" - markersize = 10 - else: - # normal location - marker = "o" - color = "tab:blue" - label = "Cities" - markersize = 8 - if i > 1: - label = "" - - ax.plot( - loc[0], - loc[1], - color=color, - marker=marker, - markersize=markersize, - label=label, - ) - - # Plot the actions in order - agent_idx = 0 - for i in range(len(actions)): - if actions[i] == 0: - agent_idx += 1 - color = cmap(num_agents - agent_idx) - - from_node = actions[i] - to_node = ( - actions[i + 1] if i < len(actions) - 1 else actions[0] - ) # last goes back to depot - from_loc = td["locs"][from_node] - to_loc = td["locs"][to_node] - ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], color=color) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="->", color=color), - annotation_clip=False, - ) - - # Legend - handles, labels = ax.get_legend_handles_labels() - ax.legend(handles, labels) - ax.set_title("mTSP") - ax.set_xlabel("x-coordinate") - ax.set_ylabel("y-coordinate") + return render(td, actions, ax) diff --git a/rl4co/envs/routing/mtsp/generator.py b/rl4co/envs/routing/mtsp/generator.py new file mode 100644 index 00000000..871059ed --- /dev/null +++ b/rl4co/envs/routing/mtsp/generator.py @@ -0,0 +1,71 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class MTSPGenerator(Generator): + """Data generator for the Multiple Travelling Salesman Problem (mTSP). + Args: + num_loc: number of locations (customers) in the TSP + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + min_num_agents: minimum number of agents (vehicles), include + max_num_agents: maximum number of agents (vehicles), include + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + num_agents [batch_size]: number of agents (vehicles) + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_num_agents: int = 5, + max_num_agents: int = 5, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_num_agents = min_num_agents + self.max_num_agents = max_num_agents + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample the number of agents + num_agents = torch.randint( + self.min_num_agents, + self.max_num_agents + 1, + size=(*batch_size, ), + ) + + return TensorDict( + { + "locs": locs, + "num_agents": num_agents, + }, + batch_size=batch_size, + ) + \ No newline at end of file diff --git a/rl4co/envs/routing/mtsp/render.py b/rl4co/envs/routing/mtsp/render.py new file mode 100644 index 00000000..173301ae --- /dev/null +++ b/rl4co/envs/routing/mtsp/render.py @@ -0,0 +1,95 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import colormaps + +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + def discrete_cmap(num, base_cmap="nipy_spectral"): + """Create an N-bin discrete colormap from the specified input map""" + base = colormaps[base_cmap] + color_list = base(np.linspace(0, 1, num)) + cmap_name = base.name + str(num) + return base.from_list(cmap_name, color_list, num) + + if actions is None: + actions = td.get("action", None) + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + num_agents = td["num_agents"] + locs = td["locs"] + cmap = discrete_cmap(num_agents, "rainbow") + + fig, ax = plt.subplots() + + # Add depot action = 0 to before first action and after last action + actions = torch.cat( + [ + torch.zeros(1, dtype=torch.int64), + actions, + torch.zeros(1, dtype=torch.int64), + ] + ) + + # Make list of colors from matplotlib + for i, loc in enumerate(locs): + if i == 0: + # depot + marker = "s" + color = "g" + label = "Depot" + markersize = 10 + else: + # normal location + marker = "o" + color = "tab:blue" + label = "Customers" + markersize = 8 + if i > 1: + label = "" + + ax.plot( + loc[0], + loc[1], + color=color, + marker=marker, + markersize=markersize, + label=label, + ) + + # Plot the actions in order + agent_idx = 0 + for i in range(len(actions)): + if actions[i] == 0: + agent_idx += 1 + color = cmap(num_agents - agent_idx) + + from_node = actions[i] + to_node = ( + actions[i + 1] if i < len(actions) - 1 else actions[0] + ) # last goes back to depot + from_loc = td["locs"][from_node] + to_loc = td["locs"][to_node] + ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], color=color) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="->", color=color), + annotation_clip=False, + ) + + # Legend + handles, labels = ax.get_legend_handles_labels() + ax.legend(handles, labels) + ax.set_title("mTSP") + ax.set_xlabel("x-coordinate") + ax.set_ylabel("y-coordinate") diff --git a/rl4co/envs/routing/op.py b/rl4co/envs/routing/op/env.py similarity index 53% rename from rl4co/envs/routing/op.py rename to rl4co/envs/routing/op/env.py index 98d016d7..bfce8da8 100644 --- a/rl4co/envs/routing/op.py +++ b/rl4co/envs/routing/op/env.py @@ -15,11 +15,10 @@ from rl4co.utils.ops import gather_by_index, get_tour_length from rl4co.utils.pylogger import get_pylogger -log = get_pylogger(__name__) - +from .generator import OPGenerator +from .render import render -# From Kool et al. 2019 -MAX_LENGTHS = {20: 2.0, 50: 3.0, 100: 4.0} +log = get_pylogger(__name__) class OPEnv(RL4COEnvBase): @@ -27,48 +26,50 @@ class OPEnv(RL4COEnvBase): At each step, the agent chooses a location to visit in order to maximize the collected prize. The total length of the path must not exceed a given threshold. + Observations: + - location of the depot + - locations and prize of each customer + - current location of the vehicle + - current tour length + - current total prize + - the remaining length of the path + + Constraints: + - the tour starts and ends at the depot + - not all customers need to be visited + - the vehicle cannot visit customers exceed the remaining length of the path + + Finish Condition: + - the vehicle back to the depot + + Reward: + - the sum of the prizes of visited nodes + Args: - num_loc: number of locations (cities) in the OP - min_loc: minimum value of the locations - max_loc: maximum value of the locations - max_length: maximum length of the path - prize_type: type of prize to collect. Can be: - - "dist": the prize is the distance from the previous location - - "unif": the prize is a uniform random variable - - "const": the prize is a constant - td_params: parameters of the environment + generator: OPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "op" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - max_length: Union[float, torch.Tensor] = None, + generator: OPGenerator = None, + generator_params: dict = {}, prize_type: str = "dist", - td_params: TensorDict = None, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.max_length = ( - MAX_LENGTHS.get(num_loc, None) if max_length is None else max_length - ) - if self.max_length is None: - raise ValueError( - f"`max_length` must be specified for num_loc={num_loc}. Please specify it manually." - ) + if generator is None: + generator = OPGenerator(**generator_params) + self.generator = generator self.prize_type = prize_type assert self.prize_type in [ "dist", "unif", "const", ], f"Invalid prize_type: {self.prize_type}" - self._make_spec(td_params) + self._make_spec(self.generator) def _step(self, td: TensorDict) -> TensorDict: current_node = td["action"][:, None] @@ -111,12 +112,8 @@ def _reset( td: Optional[TensorDict] = None, batch_size: Optional[list] = None, ) -> TensorDict: - # Initialize params - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - self.to(td.device) + device = td.device + # Add depot to locs locs_with_depot = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -127,25 +124,25 @@ def _reset( "prize": F.pad( td["prize"], (1, 0), mode="constant", value=0 ), # add 0 for depot - "tour_length": torch.zeros(*batch_size, device=self.device), + "tour_length": torch.zeros(*batch_size, device=device), # max_length is max length allowed when arriving at node, so subtract distance to return to depot # Additionally, substract epsilon margin for numeric stability "max_length": td["max_length"][..., None] - (td["depot"][..., None, :] - locs_with_depot).norm(p=2, dim=-1) - 1e-6, "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device + *batch_size, 1, dtype=torch.long, device=device ), "visited": torch.zeros( (*batch_size, locs_with_depot.shape[-2]), dtype=torch.bool, - device=self.device, + device=device, ), "current_total_prize": torch.zeros( - *batch_size, 1, dtype=torch.float, device=self.device + *batch_size, 1, dtype=torch.float, device=device ), "i": torch.zeros( - (*batch_size,), dtype=torch.int64, device=self.device + (*batch_size,), dtype=torch.int64, device=device ), # counter }, batch_size=batch_size, @@ -171,17 +168,13 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: action_mask[..., 0] = 1 return action_mask - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: """Reward is the sum of the prizes of visited nodes""" # In case all tours directly return to depot, prevent further problems if actions.size(-1) == 1: assert (actions == 0).all(), "If all length 1 tours, they should be zero" return torch.zeros(actions.size(0), dtype=torch.float, device=actions.device) - # Check that the solution is valid - if self.check_solution: - self.check_solution_validity(td, actions) - # Prize is the sum of the prizes of the visited nodes. Note that prize is padded with 0 for depot at index 0 collected_prize = td["prize"].gather(1, actions) return collected_prize.sum(-1) @@ -220,62 +213,13 @@ def check_solution_validity( (length[..., None] - max_length).max() ) - def generate_data(self, batch_size, prize_type=None) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - prize_type = self.prize_type if prize_type is None else prize_type - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Methods taken from Fischetti et al. (1998) and Kool et al. (2019) - if prize_type == "const": - prize = torch.ones(*batch_size, self.num_loc, device=self.device) - elif prize_type == "unif": - prize = ( - 1 - + torch.randint( - 0, 100, (*batch_size, self.num_loc), device=self.device - ).float() - ) / 100 - elif prize_type == "dist": # based on the distance to the depot - prize = (locs_with_depot[..., 0:1, :] - locs_with_depot[..., 1:, :]).norm( - p=2, dim=-1 - ) - prize = ( - 1 + (prize / prize.max(dim=-1, keepdim=True)[0] * 99).int() - ).float() / 100 - else: - raise ValueError(f"Invalid prize_type: {self.prize_type}") - - # Support for heterogeneous max length if provided - if not isinstance(self.max_length, torch.Tensor): - max_length = torch.full((*batch_size,), self.max_length, device=self.device) - else: - max_length = self.max_length - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "prize": prize, - "max_length": max_length, - }, - batch_size=batch_size, - ) - - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: OPGenerator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -283,15 +227,15 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), prize=UnboundedContinuousTensorSpec( - shape=(self.num_loc,), + shape=(generator.num_loc,), dtype=torch.float32, ), tour_length=UnboundedContinuousTensorSpec( - shape=(self.num_loc,), + shape=(generator.num_loc,), dtype=torch.float32, ), visited=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1,), + shape=(generator.num_loc + 1,), dtype=torch.bool, ), max_length=UnboundedContinuousTensorSpec( @@ -299,7 +243,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.float32, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), + shape=(generator.num_loc + 1, 1), dtype=torch.bool, ), shape=(), @@ -308,86 +252,11 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc + 1, + high=generator.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) @staticmethod - def render(td: TensorDict, actions=None, ax=None): - import matplotlib.pyplot as plt - import numpy as np - - # Create a plot of the nodes - if ax is None: - _, ax = plt.subplots() - - td = td.detach().cpu() - - # Actions - if actions is None: - actions = td.get("action", None) - actions = actions.detach().cpu() if actions is not None else None - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] if actions is not None else None - - # Variables - depot = td["locs"][0, :] - cities = td["locs"][1:, :] - prizes = td["prize"][1:] - normalized_prizes = ( - 200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) - + 10 - ) - - # Plot depot and cities with prize - ax.scatter( - depot[0], - depot[1], - marker="s", - c="tab:green", - edgecolors="black", - zorder=5, - s=100, - ) # Plot depot as square - ax.scatter( - cities[:, 0], - cities[:, 1], - s=normalized_prizes, - c=normalized_prizes, - cmap="autumn_r", - alpha=0.6, - edgecolors="black", - ) # Plot all cities with size and color indicating the prize - - # Gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - # Reorder the cities and their corresponding prizes based on actions - tour = cities[actions - 1] # subtract 1 to match Python's 0-indexing - - # Append the depot at the beginning and the end of the tour - tour = np.vstack((depot, tour, depot)) - - # Use quiver to plot the tour - dx, dy = np.diff(tour[:, 0]), np.diff(tour[:, 1]) - ax.quiver( - tour[:-1, 0], - tour[:-1, 1], - dx, - dy, - scale_units="xy", - angles="xy", - scale=1, - zorder=2, - color="black", - width=0.0035, - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + return render(td, actions, ax) diff --git a/rl4co/envs/routing/op/generator.py b/rl4co/envs/routing/op/generator.py new file mode 100644 index 00000000..4cdd49fd --- /dev/null +++ b/rl4co/envs/routing/op/generator.py @@ -0,0 +1,140 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + +# From Kool et al. 2019 +MAX_LENGTHS = {20: 2.0, 50: 3.0, 100: 4.0} + + +class OPGenerator(Generator): + """Data generator for the Orienteering Problem (OP). + Args: + num_loc: number of locations (customers) in the OP, without the depot. (e.g. 10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + min_prize: minimum value for the prize of each customer + max_prize: maximum value for the prize of each customer + prize_distribution: distribution for the prize of each customer + max_length: maximum length of the path + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + prize [batch_size, num_loc]: prize of each customer + max_length [batch_size, 1]: maximum length of the path for each customer + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_prize: float = 1.0, + max_prize: float = 1.0, + prize_distribution: Union[ + int, float, type, Callable + ] = Uniform, + prize_type: str = "dist", + max_length: Union[float, torch.Tensor] = None, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_prize = min_prize + self.max_prize = max_prize + self.prize_type = prize_type + self.max_length = max_length + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + # Prize distribution + if kwargs.get("prize_sampler", None) is not None: + self.prize_sampler = kwargs["prize_sampler"] + elif prize_distribution == 'dist': # If prize_distribution is 'dist', then the prize is the distance from the depot + self.prize_sampler = None + else: + self.prize_sampler = get_sampler("prize", prize_distribution, min_prize, max_prize, **kwargs) + + # Max length + if max_length is not None: + self.max_length = max_length + else: + self.max_length = MAX_LENGTHS.get(num_loc, None) + if self.max_length is None: + closest_num_loc = min(MAX_LENGTHS.keys(), key=lambda x: abs(x - num_loc)) + self.max_length = MAX_LENGTHS[closest_num_loc] + log.warning( + f"The max length for {num_loc} locations is not defined. Using the closest max length: {self.max_length}\ + with {closest_num_loc} locations." + ) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + locs_with_depot = torch.cat((depot.unsqueeze(1), locs), dim=1) + + # Methods taken from Fischetti et al. (1998) and Kool et al. (2019) + if self.prize_type == "const": + prize = torch.ones(*batch_size, self.num_loc, device=self.device) + elif self.prize_type == "unif": + prize = ( + 1 + + torch.randint( + 0, 100, (*batch_size, self.num_loc), device=self.device + ).float() + ) / 100 + elif self.prize_type == "dist": # based on the distance to the depot + prize = (locs_with_depot[..., 0:1, :] - locs_with_depot[..., 1:, :]).norm( + p=2, dim=-1 + ) + prize = ( + 1 + (prize / prize.max(dim=-1, keepdim=True)[0] * 99).int() + ).float() / 100 + else: + raise ValueError(f"Invalid prize_type: {self.prize_type}") + + # Support for heterogeneous max length if provided + if not isinstance(self.max_length, torch.Tensor): + max_length = torch.full((*batch_size,), self.max_length) + else: + max_length = self.max_length + + return TensorDict( + { + "locs": locs_with_depot[..., 1:, :], + "depot": locs_with_depot[..., 0, :], + "prize": prize, + "max_length": max_length, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/op/render.py b/rl4co/envs/routing/op/render.py new file mode 100644 index 00000000..65ad40be --- /dev/null +++ b/rl4co/envs/routing/op/render.py @@ -0,0 +1,86 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + # Create a plot of the nodes + if ax is None: + _, ax = plt.subplots() + + td = td.detach().cpu() + + # Actions + if actions is None: + actions = td.get("action", None) + actions = actions.detach().cpu() if actions is not None else None + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] if actions is not None else None + + # Variables + depot = td["locs"][0, :] + customers = td["locs"][1:, :] + prizes = td["prize"][1:] + normalized_prizes = ( + 200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) + + 10 + ) + + # Plot depot and customers with prize + ax.scatter( + depot[0], + depot[1], + marker="s", + c="tab:green", + edgecolors="black", + zorder=5, + s=100, + ) # Plot depot as square + ax.scatter( + customers[:, 0], + customers[:, 1], + s=normalized_prizes, + c=normalized_prizes, + cmap="autumn_r", + alpha=0.6, + edgecolors="black", + ) # Plot all customers with size and color indicating the prize + + # Gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + # Reorder the customers and their corresponding prizes based on actions + tour = customers[actions - 1] # subtract 1 to match Python's 0-indexing + + # Append the depot at the beginning and the end of the tour + tour = np.vstack((depot, tour, depot)) + + # Use quiver to plot the tour + dx, dy = np.diff(tour[:, 0]), np.diff(tour[:, 1]) + ax.quiver( + tour[:-1, 0], + tour[:-1, 1], + dx, + dy, + scale_units="xy", + angles="xy", + scale=1, + zorder=2, + color="black", + width=0.0035, + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/pctsp.py b/rl4co/envs/routing/pctsp/env.py similarity index 50% rename from rl4co/envs/routing/pctsp.py rename to rl4co/envs/routing/pctsp/env.py index 24ac66ec..521a3b05 100644 --- a/rl4co/envs/routing/pctsp.py +++ b/rl4co/envs/routing/pctsp/env.py @@ -15,16 +15,10 @@ from rl4co.utils.ops import gather_by_index, get_tour_length from rl4co.utils.pylogger import get_pylogger -log = get_pylogger(__name__) +from .generator import PCTSPGenerator +from .render import render -# For the penalty to make sense it should be not too large (in which case all nodes will be visited) nor too small -# so we want the objective term to be approximately equal to the length of the tour, which we estimate with half -# of the nodes by half of the tour length (which is very rough but similar to op) -# This means that the sum of penalties for all nodes will be approximately equal to the tour length (on average) -# The expected total (uniform) penalty of half of the nodes (since approx half will be visited by the constraint) -# is (n / 2) / 2 = n / 4 so divide by this means multiply by 4 / n, -# However instead of 4 we use penalty_factor (3 works well) so we can make them larger or smaller -MAX_LENGTHS = {20: 2.0, 50: 3.0, 100: 4.0} +log = get_pylogger(__name__) class PCTSPEnv(RL4COEnvBase): @@ -32,14 +26,29 @@ class PCTSPEnv(RL4COEnvBase): The goal is to collect as much prize as possible while minimizing the total travel cost. The environment is stochastic, the prize is only revealed when the node is visited. + Observations: + - locations of the nodes + - prize and penalty of each node + - current location of the vehicle + - current total prize + - current total penalty + - visited nodes + - prize required to visit a node + - the current step + + Constraints: + - the tour starts and ends at the depot + - the vehicle cannot visit nodes exceed the remaining prize + + Finish Condition: + - the vehicle back to the depot + + Reward: + - the sum of the saved penalties + Args: - num_loc: Number of locations - min_loc: Minimum location value - max_loc: Maximum location value - penalty_factor: Penalty factor - prize_required: Minimum prize required to visit a node - check_solution: Set to False by default for small bug happening around 0.01% of the time (TODO: fix) - td_params: Parameters of the environment + generator: OPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "pctsp" @@ -47,22 +56,15 @@ class PCTSPEnv(RL4COEnvBase): def __init__( self, - num_loc: int = 10, - min_loc: float = 0, - max_loc: float = 1, - penalty_factor: float = 3, - prize_required: float = 1, - check_solution: bool = False, - td_params: TensorDict = None, + generator: PCTSPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.penalty_factor = penalty_factor - self.prize_required = prize_required - self.check_solution = check_solution + if generator is None: + generator = PCTSPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) def _step(self, td: TensorDict) -> TensorDict: current_node = td["action"] @@ -102,11 +104,7 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - self.to(td.device) + device = td.device locs = torch.cat([td["depot"][..., None, :], td["locs"]], dim=-2) expected_prize = td["deterministic_prize"] @@ -122,19 +120,19 @@ def _reset( penalty_with_depot = F.pad(penalty, (1, 0), mode="constant", value=0) # Initialize the current node and prize / penalty - current_node = torch.zeros((*batch_size,), dtype=torch.int64, device=self.device) - cur_total_prize = torch.zeros(*batch_size, device=self.device) + current_node = torch.zeros((*batch_size,), dtype=torch.int64, device=device) + cur_total_prize = torch.zeros(*batch_size, device=device) cur_total_penalty = penalty.sum(-1)[ :, None ] # Sum penalties (all when nothing is visited), add step dim # Init the action mask (all nodes are available) visited = torch.zeros( - (*batch_size, self.num_loc + 1), dtype=torch.bool, device=self.device + (*batch_size, self.generator.num_loc + 1), dtype=torch.bool, device=device ) - i = torch.zeros((*batch_size,), dtype=torch.int64, device=self.device) + i = torch.zeros((*batch_size,), dtype=torch.int64, device=device) prize_required = torch.full( - (*batch_size,), self.prize_required, device=self.device + (*batch_size,), self.generator.prize_required, device=device ) td_reset = TensorDict( @@ -164,7 +162,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: ) return ~(mask > 0) # Invert mask, since 1 means feasible action - def get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: + def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: """Reward is `saved penalties - (total length + penalty)`""" # In case all tours directly return to depot, prevent further problems @@ -172,10 +170,6 @@ def get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: assert (actions == 0).all(), "If all length 1 tours, they should be zero" return torch.zeros(actions.size(0), dtype=torch.float, device=actions.device) - # Check that the solution is valid - if self.check_solution: - self.check_solution_validity(td, actions) - # Gather locations in order of tour and get the length of tours locs_ordered = gather_by_index(td["locs"], actions) length = get_tour_length(locs_ordered) @@ -210,52 +204,6 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): ) # no depot ).all(), "Total prize does not satisfy min total prize" - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - depot = torch.rand((*batch_size, 2)) - locs = torch.rand((*batch_size, self.num_loc, 2)) - - penalty_max = ( - MAX_LENGTHS[self.num_loc] * (self.penalty_factor) / float(self.num_loc) - ) - penalty = torch.rand((*batch_size, self.num_loc)) * penalty_max - - # Take uniform prizes - # Now expectation is 0.5 so expected total prize is n / 2, we want to force to visit approximately half of the nodes - # so the constraint will be that total prize >= (n / 2) / 2 = n / 4 - # equivalently, we divide all prizes by n / 4 and the total prize should be >= 1 - deterministic_prize = ( - torch.rand((*batch_size, self.num_loc)) * 4 / float(self.num_loc) - ) - - # In the deterministic setting, the stochastic_prize is not used and the deterministic prize is known - # In the stochastic setting, the deterministic prize is the expected prize and is known up front but the - # stochastic prize is only revealed once the node is visited - # Stochastic prize is between (0, 2 * expected_prize) such that E(stochastic prize) = E(deterministic_prize) - stochastic_prize = ( - torch.rand((*batch_size, self.num_loc)) * deterministic_prize * 2 - ) - # In the deterministic setting, the stochastic_prize is not used and the deterministic prize is known - # In the stochastic setting, the deterministic prize is the expected prize and is known up front but the - # stochastic prize is only revealed once the node is visited - # Stochastic prize is between (0, 2 * expected_prize) such that E(stochastic prize) = E(deterministic_prize) - stochastic_prize = ( - torch.rand((*batch_size, self.num_loc)) * deterministic_prize * 2 - ) - - return TensorDict( - { - "locs": locs, - "depot": depot, - "penalty": penalty, - "deterministic_prize": deterministic_prize, - "stochastic_prize": stochastic_prize, - }, - batch_size=batch_size, - ) - @property def stochastic(self): return self._stochastic @@ -267,13 +215,13 @@ def stochastic(self, state: bool): "Stochastic mode should not be used for PCTSP. Use SPCTSP instead." ) - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator): """Make the locs and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -281,15 +229,15 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), expected_prize=UnboundedContinuousTensorSpec( - shape=(self.num_loc), + shape=(generator.num_loc), dtype=torch.float32, ), real_prize=UnboundedContinuousTensorSpec( - shape=(self.num_loc + 1), + shape=(generator.num_loc + 1), dtype=torch.float32, ), penalty=UnboundedContinuousTensorSpec( - shape=(self.num_loc + 1), + shape=(generator.num_loc + 1), dtype=torch.float32, ), cur_total_prize=UnboundedContinuousTensorSpec( @@ -301,7 +249,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.float32, ), visited=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1), + shape=(generator.num_loc + 1), dtype=torch.bool, ), prize_required=UnboundedContinuousTensorSpec( @@ -313,7 +261,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc), + shape=(generator.num_loc), dtype=torch.bool, ), shape=(), @@ -322,99 +270,7 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc, + high=generator.num_loc, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render(td, actions=None, ax=None): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import colormaps - - # Create a plot of the nodes - if ax is None: - _, ax = plt.subplots() - - td = td.detach().cpu() - - # Actions - if actions is None: - actions = td.get("action", None) - actions = actions.detach().cpu() if actions is not None else None - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] if actions is not None else None - - # Variables - depot = td["locs"][0, :] - cities = td["locs"][1:, :] - prizes = td["real_prize"][1:] - penalties = td["penalty"][1:] - normalized_prizes = ( - 200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) - + 10 - ) - normalized_penalties = ( - 3 - * (penalties - torch.min(penalties)) - / (torch.max(penalties) - torch.min(penalties)) - ) - - # Represent penalty with colormap and size of edges - penalty_cmap = colormaps.get_cmap("BuPu") - penalty_colors = penalty_cmap(normalized_penalties) - - # Plot depot and cities with prize (size of nodes) and penalties (size of borders) - ax.scatter( - depot[0], - depot[1], - marker="s", - c="tab:green", - edgecolors="black", - zorder=1, - s=100, - ) # Plot depot as square - ax.scatter( - cities[:, 0], - cities[:, 1], - s=normalized_prizes, - c=normalized_prizes, - cmap="autumn_r", - alpha=1, - edgecolors=penalty_colors, - linewidths=normalized_penalties, - ) # Plot all cities with size and color indicating the prize - - # Gather locs in order of action if available - if actions is None: - print("No action in TensorDict, rendering unsorted locs") - else: - # Reorder the cities and their corresponding prizes based on actions - tour = cities[actions - 1] # subtract 1 to match Python's 0-indexing - - # Append the depot at the beginning and the end of the tour - tour = np.vstack((depot, tour, depot)) - - # Use quiver to plot the tour - dx, dy = np.diff(tour[:, 0]), np.diff(tour[:, 1]) - ax.quiver( - tour[:-1, 0], - tour[:-1, 1], - dx, - dy, - scale_units="xy", - angles="xy", - scale=1, - zorder=2, - color="black", - width=0.0035, - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/pctsp/generator.py b/rl4co/envs/routing/pctsp/generator.py new file mode 100644 index 00000000..5863fa50 --- /dev/null +++ b/rl4co/envs/routing/pctsp/generator.py @@ -0,0 +1,116 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + +# For the penalty to make sense it should be not too large (in which case all nodes will be visited) nor too small +# so we want the objective term to be approximately equal to the length of the tour, which we estimate with half +# of the nodes by half of the tour length (which is very rough but similar to op) +# This means that the sum of penalties for all nodes will be approximately equal to the tour length (on average) +# The expected total (uniform) penalty of half of the nodes (since approx half will be visited by the constraint) +# is (n / 2) / 2 = n / 4 so divide by this means multiply by 4 / n, +# However instead of 4 we use penalty_factor (3 works well) so we can make them larger or smaller +MAX_LENGTHS = {20: 2.0, 50: 3.0, 100: 4.0} + + +class PCTSPGenerator(Generator): + """Data generator for the Prize-collecting Traveling Salesman Problem (PCTSP). + Args: + num_loc: number of locations (customers) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + depot_distribution: distribution for the depot location + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + demand_distribution: distribution for the demand of each customer + capacity: capacity of the vehicle + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each city + depot [batch_size, 2]: location of the depot + demand [batch_size, num_loc]: demand of each customer + capacity [batch_size, 1]: capacity of the vehicle + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + penalty_factor: float = 3.0, + prize_required: float = 1.0, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.penalty_fctor = penalty_factor + self.prize_required = prize_required + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + # Prize distribution + self.deterministic_prize_sampler = get_sampler("deterministric_prize", "uniform", 0.0, 4.0/self.num_loc, **kwargs) + self.stochastic_prize_sampler = get_sampler("stochastic_prize", "uniform", 0.0, 8.0/self.num_loc, **kwargs) + + # Penalty + self.max_penalty = kwargs.get("max_penalty", None) + if self.max_penalty is None: # If not provided, use the default max penalty + self.max_penalty = MAX_LENGTHS.get(num_loc, None) + if self.max_penalty is None: # If not in the table keys, find the closest number of nodes as the key + closest_num_loc = min(MAX_LENGTHS.keys(), key=lambda x: abs(x - num_loc)) + self.max_penalty = MAX_LENGTHS[closest_num_loc] + log.warning( + f"The max penalty for {num_loc} locations is not defined. Using the closest max penalty: {self.max_penalty}\ + with {closest_num_loc} locations." + ) + self.penalty_sampler = get_sampler("penalty", "uniform", 0.0, self.max_penalty, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + # Sample penalty + penalty = self.penalty_sampler.sample((*batch_size, self.num_loc)) + + # Sampel prize + deterministic_prize = self.deterministic_prize_sampler.sample((*batch_size, self.num_loc)) + stochastic_prize = self.stochastic_prize_sampler.sample((*batch_size, self.num_loc)) + + return TensorDict( + { + "locs": locs, + "depot": depot, + "penalty": penalty, + "deterministic_prize": deterministic_prize, + "stochastic_prize": stochastic_prize, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/pctsp/render.py b/rl4co/envs/routing/pctsp/render.py new file mode 100644 index 00000000..7d0e3622 --- /dev/null +++ b/rl4co/envs/routing/pctsp/render.py @@ -0,0 +1,93 @@ +import torch + + +def render(td, actions=None, ax=None): + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import colormaps + + # Create a plot of the nodes + if ax is None: + _, ax = plt.subplots() + + td = td.detach().cpu() + + # Actions + if actions is None: + actions = td.get("action", None) + actions = actions.detach().cpu() if actions is not None else None + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] if actions is not None else None + + # Variables + depot = td["locs"][0, :] + customers = td["locs"][1:, :] + prizes = td["real_prize"][1:] + penalties = td["penalty"][1:] + normalized_prizes = ( + 200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) + + 10 + ) + normalized_penalties = ( + 3 + * (penalties - torch.min(penalties)) + / (torch.max(penalties) - torch.min(penalties)) + ) + + # Represent penalty with colormap and size of edges + penalty_cmap = colormaps.get_cmap("BuPu") + penalty_colors = penalty_cmap(normalized_penalties) + + # Plot depot and customers with prize (size of nodes) and penalties (size of borders) + ax.scatter( + depot[0], + depot[1], + marker="s", + c="tab:green", + edgecolors="black", + zorder=1, + s=100, + ) # Plot depot as square + ax.scatter( + customers[:, 0], + customers[:, 1], + s=normalized_prizes, + c=normalized_prizes, + cmap="autumn_r", + alpha=1, + edgecolors=penalty_colors, + linewidths=normalized_penalties, + ) # Plot all customers with size and color indicating the prize + + # Gather locs in order of action if available + if actions is None: + print("No action in TensorDict, rendering unsorted locs") + else: + # Reorder the customers and their corresponding prizes based on actions + tour = customers[actions - 1] # subtract 1 to match Python's 0-indexing + + # Append the depot at the beginning and the end of the tour + tour = np.vstack((depot, tour, depot)) + + # Use quiver to plot the tour + dx, dy = np.diff(tour[:, 0]), np.diff(tour[:, 1]) + ax.quiver( + tour[:-1, 0], + tour[:-1, 1], + dx, + dy, + scale_units="xy", + angles="xy", + scale=1, + zorder=2, + color="black", + width=0.0035, + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/pdp.py b/rl4co/envs/routing/pdp/env.py similarity index 55% rename from rl4co/envs/routing/pdp.py rename to rl4co/envs/routing/pdp/env.py index a23d086e..a7959385 100644 --- a/rl4co/envs/routing/pdp.py +++ b/rl4co/envs/routing/pdp/env.py @@ -13,6 +13,9 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.utils.ops import gather_by_index, get_tour_length +from .generator import PDPGenerator +from .render import render + class PDPEnv(RL4COEnvBase): """Pickup and Delivery Problem (PDP) environment. @@ -23,28 +26,42 @@ class PDPEnv(RL4COEnvBase): The goal is to visit all the pickup and delivery locations in the shortest path possible starting from the depot The conditions is that the agent must visit a pickup location before visiting its corresponding delivery location + Observations: + - locations of the depot, pickup, and delivery locations + - current location of the vehicle + - the remaining locations to deliver + - the visited locations + - the current step + + Constraints: + - the tour starts and ends at the depot + - each pickup location must be visited before its corresponding delivery location + - the vehicle cannot visit the same location twice + + Finish Condition: + - the vehicle has visited all locations + + Reward: + - (minus) the negative length of the path + Args: - num_loc: number of locations (cities) in the TSP - td_params: parameters of the environment - seed: seed for the environment - device: device to use. Generally, no need to set as tensors are updated on the fly + generator: PDPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "pdp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - td_params: TensorDict = None, + generator: PDPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self._make_spec(td_params) + if generator is None: + generator = PDPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) @staticmethod def _step(td: TensorDict) -> TensorDict: @@ -89,13 +106,7 @@ def _step(td: TensorDict) -> TensorDict: return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td.batch_size - - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - - self.to(td.device) + device = td.device locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -104,12 +115,12 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict [ torch.ones( *batch_size, - self.num_loc // 2 + 1, + self.generator.num_loc // 2 + 1, dtype=torch.bool, - device=self.device, + device=device, ), torch.zeros( - *batch_size, self.num_loc // 2, dtype=torch.bool, device=self.device + *batch_size, self.generator.num_loc // 2, dtype=torch.bool, device=device ), ], dim=-1, @@ -117,16 +128,16 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict # Cannot visit depot at first step # [0,1...1] so set not available available = torch.ones( - (*batch_size, self.num_loc + 1), dtype=torch.bool, device=self.device + (*batch_size, self.generator.num_loc + 1), dtype=torch.bool, device=device ) action_mask = ~available.contiguous() # [batch_size, graph_size+1] action_mask[..., 0] = 1 # First step is always the depot # Other variables current_node = torch.zeros( - (*batch_size, 1), dtype=torch.int64, device=self.device + (*batch_size, 1), dtype=torch.int64, device=device ) - i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=self.device) + i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) return TensorDict( { @@ -140,13 +151,13 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: PDPGenerator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -162,7 +173,7 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1), + shape=(generator.num_loc + 1), dtype=torch.bool, ), shape=(), @@ -171,13 +182,18 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc + 1, + high=generator.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) @staticmethod - def get_reward(td, actions) -> TensorDict: + def _get_reward(td, actions) -> TensorDict: + # Gather locations in the order of actions and get reward = -(total distance) + locs_ordered = gather_by_index(td["locs"], actions) # [batch, graph_size+1, 2] + return -get_tour_length(locs_ordered) + + def check_solution_validity(self, td, actions): # assert (actions[:, 0] == 0).all(), "Not starting at depot" assert ( torch.arange(actions.size(1), out=actions.data.new()) @@ -193,100 +209,3 @@ def get_reward(td, actions) -> TensorDict: visited_time[:, 1 : actions.size(1) // 2 + 1] < visited_time[:, actions.size(1) // 2 + 1 :] ).all(), "Deliverying without pick-up" - - # Gather locations in the order of actions and get reward = -(total distance) - locs_ordered = gather_by_index(td["locs"], actions) # [batch, graph_size+1, 2] - return -get_tour_length(locs_ordered) - - def generate_data(self, batch_size) -> TensorDict: - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - }, - batch_size=batch_size, - ) - - @staticmethod - def render(td: TensorDict, actions=None, ax=None): - import matplotlib.pyplot as plt - - markersize = 8 - - td = td.detach().cpu() - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - if actions is not None: - actions = actions[0] - - # Variables - init_deliveries = td["to_deliver"][1:] - delivery_locs = td["locs"][1:][~init_deliveries.bool()] - pickup_locs = td["locs"][1:][init_deliveries.bool()] - depot_loc = td["locs"][0] - actions = actions if actions is not None else td["action"] - - fig, ax = plt.subplots() - - # Plot the actions in order - for i in range(len(actions)): - from_node = actions[i] - to_node = ( - actions[i + 1] if i < len(actions) - 1 else actions[0] - ) # last goes back to depot - from_loc = td["locs"][from_node] - to_loc = td["locs"][to_node] - ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="->", color="black"), - annotation_clip=False, - ) - - # Plot the depot location - ax.plot( - depot_loc[0], - depot_loc[1], - "g", - marker="s", - markersize=markersize, - label="Depot", - ) - - # Plot the pickup locations - for i, pickup_loc in enumerate(pickup_locs): - ax.plot( - pickup_loc[0], - pickup_loc[1], - "r", - marker="^", - markersize=markersize, - label="Pickup" if i == 0 else None, - ) - - # Plot the delivery locations - for i, delivery_loc in enumerate(delivery_locs): - ax.plot( - delivery_loc[0], - delivery_loc[1], - "b", - marker="v", - markersize=markersize, - label="Delivery" if i == 0 else None, - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/pdp/generator.py b/rl4co/envs/routing/pdp/generator.py new file mode 100644 index 00000000..df96cbdc --- /dev/null +++ b/rl4co/envs/routing/pdp/generator.py @@ -0,0 +1,78 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class PDPGenerator(Generator): + """Data generator for the Pickup and Delivery Problem (PDP). + Args: + num_loc: number of locations (customers) in the PDP, without the depot. (e.g. 10 means 10 locs + 1 depot) + - 1 depot + - `num_loc` / 2 pickup locations + - `num_loc` / 2 delivery locations + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + depot_distribution: distribution for the depot location + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + + # Number of locations must be even + if num_loc % 2 != 0: + log.warn("Number of locations must be even. Adding 1 to the number of locations.") + self.num_loc += 1 + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + return TensorDict( + { + "locs": locs, + "depot": depot, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/pdp/render.py b/rl4co/envs/routing/pdp/render.py new file mode 100644 index 00000000..4a3454b3 --- /dev/null +++ b/rl4co/envs/routing/pdp/render.py @@ -0,0 +1,83 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + markersize = 8 + + td = td.detach().cpu() + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + if actions is not None: + actions = actions[0] + + # Variables + init_deliveries = td["to_deliver"][1:] + delivery_locs = td["locs"][1:][~init_deliveries.bool()] + pickup_locs = td["locs"][1:][init_deliveries.bool()] + depot_loc = td["locs"][0] + actions = actions if actions is not None else td["action"] + + fig, ax = plt.subplots() + + # Plot the actions in order + for i in range(len(actions)): + from_node = actions[i] + to_node = ( + actions[i + 1] if i < len(actions) - 1 else actions[0] + ) # last goes back to depot + from_loc = td["locs"][from_node] + to_loc = td["locs"][to_node] + ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="->", color="black"), + annotation_clip=False, + ) + + # Plot the depot location + ax.plot( + depot_loc[0], + depot_loc[1], + "g", + marker="s", + markersize=markersize, + label="Depot", + ) + + # Plot the pickup locations + for i, pickup_loc in enumerate(pickup_locs): + ax.plot( + pickup_loc[0], + pickup_loc[1], + "r", + marker="^", + markersize=markersize, + label="Pickup" if i == 0 else None, + ) + + # Plot the delivery locations + for i, delivery_loc in enumerate(delivery_locs): + ax.plot( + delivery_loc[0], + delivery_loc[1], + "b", + marker="v", + markersize=markersize, + label="Delivery" if i == 0 else None, + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/routing/sdvrp.py b/rl4co/envs/routing/sdvrp/env.py similarity index 73% rename from rl4co/envs/routing/sdvrp.py rename to rl4co/envs/routing/sdvrp/env.py index dcd90906..055ca7de 100644 --- a/rl4co/envs/routing/sdvrp.py +++ b/rl4co/envs/routing/sdvrp/env.py @@ -13,7 +13,8 @@ from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger -from .cvrp import CVRPEnv +from ..cvrp.env import CVRPEnv +from ..cvrp.generator import CVRPGenerator log = get_pylogger(__name__) @@ -23,45 +24,41 @@ class SDVRPEnv(CVRPEnv): SDVRP is a generalization of CVRP, where nodes can be visited multiple times and a fraction of the demand can be met. At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is the -infinite unless the agent visits all the cities. + visit any customer, the agent must go back to the depot. The reward is the -infinite unless the agent visits all the customers. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + Observations: + - location of the depot. + - locations and demand/remaining demand of each customer + - current location of the vehicle. + - the remaining capacity of the vehicle. + + Constraints: + - the tour starts and ends at the depot. + - each customer can be visited multiple times. + - the vehicle cannot visit customers exceed the remaining capacity. + - the vehicle can return to the depot to refill the capacity. + + Finish Condition: + - the vehicle has finished all customers demand and returned to the depot. + + Reward: + - (minus) the negative length of the path. + Args: - num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) - min_loc: minimum value for the location coordinates - max_loc: maximum value for the location coordinates - min_demand: minimum value for the demand of each customer - max_demand: maximum value for the demand of each customer - vehicle_capacity: capacity of the vehicle - capacity: capacity of the vehicle - td_params: parameters of the environment + generator: CVRPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "sdvrp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_demand: float = 1, - max_demand: float = 10, - vehicle_capacity: float = 1.0, - capacity: float = None, - td_params: TensorDict = None, + generator: CVRPGenerator = None, + generator_params: dict = {}, **kwargs, ): - super().__init__( - num_loc=num_loc, - min_loc=min_loc, - max_loc=max_loc, - min_demand=min_demand, - max_demand=max_demand, - vehicle_capacity=vehicle_capacity, - capacity=capacity, - td_params=td_params, - **kwargs, - ) + super().__init__(generator, generator_params, **kwargs) def _step(self, td: TensorDict) -> TensorDict: # Update the state @@ -109,13 +106,7 @@ def _reset( td: Optional[TensorDict] = None, batch_size: Optional[list] = None, ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - - self.to(td.device) + device = td.device # Create reset TensorDict reset_td = TensorDict( @@ -126,11 +117,11 @@ def _reset( (torch.zeros_like(td["demand"][..., 0:1]), td["demand"]), -1 ), "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device + *batch_size, 1, dtype=torch.long, device=device ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "used_capacity": torch.zeros((*batch_size, 1), device=device), "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device + (*batch_size, 1), self.generator.vehicle_capacity, device=device ), }, batch_size=batch_size, @@ -172,13 +163,13 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): a_prev = a assert (demands == 0).all(), "All demand must be satisfied" - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -186,25 +177,25 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), demand=BoundedTensorSpec( - low=self.min_demand, - high=self.max_demand, - shape=(self.num_loc, 1), # demand is only for customers + low=generator.min_demand, + high=generator.max_demand, + shape=(generator.num_loc, 1), # demand is only for customers dtype=torch.float32, ), demand_with_depot=BoundedTensorSpec( - low=self.min_demand, - high=self.max_demand, - shape=(self.num_loc + 1, 1), + low=generator.min_demand, + high=generator.max_demand, + shape=(generator.num_loc + 1, 1), dtype=torch.float32, ), used_capacity=BoundedTensorSpec( low=0, - high=self.vehicle_capacity, + high=generator.vehicle_capacity, shape=(1,), dtype=torch.float32, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), + shape=(generator.num_loc + 1, 1), dtype=torch.bool, ), shape=(), @@ -213,7 +204,7 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc + 1, + high=generator.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) diff --git a/rl4co/envs/routing/spctsp.py b/rl4co/envs/routing/spctsp/env.py similarity index 95% rename from rl4co/envs/routing/spctsp.py rename to rl4co/envs/routing/spctsp/env.py index e39a31b1..4f99c070 100644 --- a/rl4co/envs/routing/spctsp.py +++ b/rl4co/envs/routing/spctsp/env.py @@ -1,6 +1,6 @@ from rl4co.utils.pylogger import get_pylogger -from .pctsp import PCTSPEnv +from ..pctsp.env import PCTSPEnv log = get_pylogger(__name__) diff --git a/rl4co/envs/routing/svrp.py b/rl4co/envs/routing/svrp/env.py similarity index 56% rename from rl4co/envs/routing/svrp.py rename to rl4co/envs/routing/svrp/env.py index cab0bc8c..1b083a59 100644 --- a/rl4co/envs/routing/svrp.py +++ b/rl4co/envs/routing/svrp/env.py @@ -15,11 +15,14 @@ from rl4co.utils.ops import gather_by_index, get_distance from rl4co.utils.pylogger import get_pylogger +from .generator import SVRPGenerator +from .render import render + log = get_pylogger(__name__) class SVRPEnv(RL4COEnvBase): - """ + """Skill-Vehicle Routing Problem (SVRP) environment. Basic Skill-VRP environment. The environment is a variant of the Capacitated Vehicle Routing Problem (CVRP). Each technician has a certain skill-level and each customer node requires a certain skill-level to be serviced. Each customer node needs is to be serviced by exactly one technician. Technicians can only service nodes if @@ -27,46 +30,51 @@ class SVRPEnv(RL4COEnvBase): the goal is to minimize the total travel cost of the technicians. The travel cost depends on the skill-level of the technician. The environment is defined by the following parameters: + Observations: + - locations of the depot, pickup, and delivery locations + - current location of the vehicle + - the remaining locations to deliver + - the visited locations + - the current step + + Constraints: + - the tour starts and ends at the depot + - each pickup location must be visited before its corresponding delivery location + - the vehicle cannot visit the same location twice + + Finish Condition: + - the vehicle has visited all locations + + Reward: + - (minus) the negative length of the path + Args: - num_loc (int): Number of customer locations. Default: 20 - min_loc (float): Minimum value for the location coordinates. Default: 0 - max_loc (float): Maximum value for the location coordinates. Default: 1 - min_skill (float): Minimum skill level of the technicians. Default: 1 - max_skill (float): Maximum skill level of the technicians. Default: 10 - tech_costs (list): List of travel costs for the technicians. Default: [1, 2, 3]. The number of entries in this list determines the number of available technicians. - td_params (TensorDict): Parameters for the TensorDict. Default: None + generator: PDPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "svrp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_skill: float = 1, - max_skill: float = 10, - tech_costs: list = [1, 2, 3], - td_params: TensorDict = None, + generator: SVRPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_skill = min_skill - self.max_skill = max_skill - self.tech_costs = tech_costs - self.num_tech = len(tech_costs) - self._make_spec(td_params) + if generator is None: + generator = SVRPGenerator(**generator_params) + self.generator = generator + self.tech_costs = self.generator.tech_costs + self._make_spec(self.generator) - def _make_spec(self, td_params: TensorDict = None): + def _make_spec(self, generator): """Make the observation and action specs from the parameters.""" self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc + 1, 2), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -74,13 +82,13 @@ def _make_spec(self, td_params: TensorDict = None): dtype=torch.int64, ), skills=BoundedTensorSpec( - low=self.min_skill, - high=self.max_skill, - shape=(self.num_loc, 1), + low=generator.min_skill, + high=generator.max_skill, + shape=(generator.num_loc, 1), dtype=torch.float32, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), + shape=(generator.num_loc + 1, 1), dtype=torch.bool, ), shape=(), @@ -89,51 +97,11 @@ def _make_spec(self, td_params: TensorDict = None): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc + 1, + high=generator.num_loc + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,), dtype=torch.float32) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - def generate_data(self, batch_size): - """Generate data for the basic Skill-VRP. The data consists of the locations of the customers, - the skill-levels of the technicians and the required skill-levels of the customers. - The data is generated randomly within the given bounds.""" - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize technicians and sort ascendingly - techs, _ = torch.sort( - torch.FloatTensor(*batch_size, self.num_tech, 1) - .uniform_(self.min_skill, self.max_skill) - .to(self.device), - dim=-2, - ) - - # Initialize the skills - skills = ( - torch.FloatTensor(*batch_size, self.num_loc, 1).uniform_(0, 1).to(self.device) - ) - # scale skills - skills = torch.max(techs, dim=1, keepdim=True).values * skills - td = TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "techs": techs, - "skills": skills, - }, - batch_size=batch_size, - device=self.device, - ) - return td - @staticmethod def get_action_mask(td: TensorDict) -> torch.Tensor: """Calculates the action mask for the Skill-VRP. The action mask is a binary mask that indicates which customer nodes can be services, given the previous decisions. @@ -186,13 +154,7 @@ def _step(self, td: TensorDict) -> torch.Tensor: def _reset( self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[0] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) + device = td.device # Create reset TensorDict td_reset = TensorDict( @@ -201,15 +163,15 @@ def _reset( "techs": td["techs"], "skills": td["skills"], "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device + *batch_size, 1, dtype=torch.long, device=device ), "current_tech": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device + *batch_size, 1, dtype=torch.long, device=device ), "visited": torch.zeros( (*batch_size, td["locs"].shape[-2] + 1, 1), dtype=torch.uint8, - device=self.device, + device=device, ), }, batch_size=batch_size, @@ -217,7 +179,7 @@ def _reset( td_reset.set("action_mask", self.get_action_mask(td_reset)) return td_reset - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: """Calculated the reward, where the reward is the negative total travel cost of the technicians. The travel cost depends on the skill-level of the technician.""" # Check that the solution is valid @@ -291,107 +253,3 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): ).all(), "Skill level not met" start = each[1] + 1 # skip the depot tech += 1 - - @staticmethod - def render( - td: TensorDict, - actions=None, - ax=None, - **kwargs, - ): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - plt.show() diff --git a/rl4co/envs/routing/svrp/generator.py b/rl4co/envs/routing/svrp/generator.py new file mode 100644 index 00000000..55726b2f --- /dev/null +++ b/rl4co/envs/routing/svrp/generator.py @@ -0,0 +1,103 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class SVRPGenerator(Generator): + """Data generator for the Skill Vehicle Routing Problem (SVRP). + Args: + num_loc: number of locations (customers) in the TSP + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + min_skill: minimum value for the technic skill + max_skill: maximum value for the technic skill + skill_distribution: distribution for the technic skill + tech_costs: list of the technic costs + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + depot [batch_size, 2]: location of the depot + techs [batch_size, num_loc]: technic requirements of each customer + skills [batch_size, num_loc]: skills of the vehicles + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + depot_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + min_skill: float = 1.0, + max_skill: float = 10.0, + tech_costs: list = [1, 2, 3], + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_skill = min_skill + self.max_skill = max_skill + self.num_tech = len(tech_costs) + self.tech_costs = torch.tensor(tech_costs) + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + # Depot distribution + if kwargs.get("depot_sampler", None) is not None: + self.depot_sampler = kwargs["depot_sampler"] + else: + self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + """Generate data for the basic Skill-VRP. The data consists of the locations of the customers, + the skill-levels of the technicians and the required skill-levels of the customers. + The data is generated randomly within the given bounds.""" + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + # Sample depot + depot = self.depot_sampler.sample((*batch_size, 2)) + + locs_with_depot = torch.cat((depot[:, None, :], locs), dim=1) + + # Initialize technicians and sort ascendingly + techs, _ = torch.sort( + torch.FloatTensor(*batch_size, self.num_tech, 1) + .uniform_(self.min_skill, self.max_skill), + dim=-2, + ) + + # Initialize the skills + skills = ( + torch.FloatTensor(*batch_size, self.num_loc, 1).uniform_(0, 1) + ) + # scale skills + skills = torch.max(techs, dim=1, keepdim=True).values * skills + td = TensorDict( + { + "locs": locs_with_depot[..., 1:, :], + "depot": locs_with_depot[..., 0, :], + "techs": techs, + "skills": skills, + }, + batch_size=batch_size, + ) + return td diff --git a/rl4co/envs/routing/svrp/render.py b/rl4co/envs/routing/svrp/render.py new file mode 100644 index 00000000..88a3d752 --- /dev/null +++ b/rl4co/envs/routing/svrp/render.py @@ -0,0 +1,103 @@ +import torch + +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import cm, colormaps + + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) diff --git a/rl4co/envs/routing/tsp.py b/rl4co/envs/routing/tsp/env.py similarity index 56% rename from rl4co/envs/routing/tsp.py rename to rl4co/envs/routing/tsp/env.py index 381eb8f3..eb1c3402 100644 --- a/rl4co/envs/routing/tsp.py +++ b/rl4co/envs/routing/tsp/env.py @@ -14,37 +14,49 @@ from rl4co.utils.ops import gather_by_index, get_tour_length from rl4co.utils.pylogger import get_pylogger +from .generator import TSPGenerator +from .render import render + log = get_pylogger(__name__) class TSPEnv(RL4COEnvBase): - """ - Traveling Salesman Problem environment + """Traveling Salesman Problem (TSP) environment At each step, the agent chooses a city to visit. The reward is 0 unless the agent visits all the cities. In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + Observations: + - locations of each customer. + - the current location of the vehicle. + + Constrains: + - the tour must return to the starting customer. + - each customer must be visited exactly once. + + Finish condition: + - the agent has visited all customers and returned to the starting customer. + + Reward: + - (minus) the negative length of the path. + Args: - num_loc: number of locations (cities) in the TSP - td_params: parameters of the environment - seed: seed for the environment - device: device to use. Generally, no need to set as tensors are updated on the fly + generator: TSPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "tsp" def __init__( self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - td_params: TensorDict = None, + generator: TSPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self._make_spec(td_params) + if generator is None: + generator = TSPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) @staticmethod def _step(td: TensorDict) -> TensorDict: @@ -76,14 +88,8 @@ def _step(td: TensorDict) -> TensorDict: def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: # Initialize locations - init_locs = td["locs"] if td is not None else None - if batch_size is None: - batch_size = self.batch_size if init_locs is None else init_locs.shape[:-2] - device = init_locs.device if init_locs is not None else self.device - self.to(device) - if init_locs is None: - init_locs = self.generate_data(batch_size=batch_size).to(device)["locs"] - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + device = td.device + init_locs = td["locs"] # We do not enforce loading from self for flexibility num_loc = init_locs.shape[-2] @@ -107,13 +113,12 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params): - """Make the observation and action specs from the parameters""" + def _make_spec(self, generator: TSPGenerator): self.observation_spec = CompositeSpec( locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc, 2), + low=generator.min_loc, + high=generator.max_loc, + shape=(generator.num_loc, 2), dtype=torch.float32, ), first_node=UnboundedDiscreteTensorSpec( @@ -129,21 +134,21 @@ def _make_spec(self, td_params): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc), + shape=(generator.num_loc), dtype=torch.bool, ), shape=(), ) self.action_spec = BoundedTensorSpec( - shape=(1,), + shape=(1), dtype=torch.int64, low=0, - high=self.num_loc, + high=generator.num_loc, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1), dtype=torch.bool) - def get_reward(self, td, actions) -> TensorDict: + def _get_reward(self, td, actions) -> TensorDict: if self.check_solution: self.check_solution_validity(td, actions) @@ -161,55 +166,6 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): == actions.data.sort(1)[0] ).all(), "Invalid tour" - def generate_data(self, batch_size) -> TensorDict: - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - locs = ( - torch.rand((*batch_size, self.num_loc, 2), generator=self.rng) - * (self.max_loc - self.min_loc) - + self.min_loc - ) - return TensorDict({"locs": locs}, batch_size=batch_size) - @staticmethod - def render(td, actions=None, ax=None): - import matplotlib.pyplot as plt - import numpy as np - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - actions = actions.detach().cpu() - locs = gather_by_index(locs, actions, dim=0) - - # Cat the first node to the end to complete the tour - locs = torch.cat((locs, locs[0:1])) - x, y = locs[:, 0], locs[:, 1] - - # Plot the visited nodes - ax.scatter(x, y, color="tab:blue") - - # Add arrows between visited nodes as a quiver plot - dx, dy = np.diff(x), np.diff(y) - ax.quiver( - x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k" - ) - - # Setup limits and show - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) + def render(td: TensorDict, actions: torch.Tensor=None, ax = None): + return render(td, actions, ax) diff --git a/rl4co/envs/routing/tsp/generator.py b/rl4co/envs/routing/tsp/generator.py new file mode 100644 index 00000000..f2c53f40 --- /dev/null +++ b/rl4co/envs/routing/tsp/generator.py @@ -0,0 +1,55 @@ +from typing import Union, Callable + +import torch + +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class TSPGenerator(Generator): + """Data generator for the Travelling Salesman Problem (TSP). + Args: + num_loc: number of locations (customers) in the TSP + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + loc_distribution: distribution for the location coordinates + + Returns: + A TensorDict with the following keys: + locs [batch_size, num_loc, 2]: locations of each customer + """ + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0.0, + max_loc: float = 1.0, + loc_distribution: Union[ + int, float, str, type, Callable + ] = Uniform, + **kwargs + ): + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + + # Location distribution + if kwargs.get("loc_sampler", None) is not None: + self.loc_sampler = kwargs["loc_sampler"] + else: + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) + + def _generate(self, batch_size) -> TensorDict: + # Sample locations + locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2)) + + return TensorDict( + { + "locs": locs, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/routing/tsp/render.py b/rl4co/envs/routing/tsp/render.py new file mode 100644 index 00000000..8ad0a903 --- /dev/null +++ b/rl4co/envs/routing/tsp/render.py @@ -0,0 +1,50 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td, actions=None, ax=None): + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # If batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + + # Gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + actions = actions.detach().cpu() + locs = gather_by_index(locs, actions, dim=0) + + # Cat the first node to the end to complete the tour + locs = torch.cat((locs, locs[0:1])) + x, y = locs[:, 0], locs[:, 1] + + # Plot the visited nodes + ax.scatter(x, y, color="tab:blue") + + # Add arrows between visited nodes as a quiver plot + dx, dy = np.diff(x), np.diff(y) + ax.quiver( + x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k" + ) + + # Setup limits and show + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py index a9c5144b..1c63820f 100644 --- a/rl4co/envs/scheduling/__init__.py +++ b/rl4co/envs/scheduling/__init__.py @@ -1,2 +1,2 @@ -from rl4co.envs.scheduling.ffsp import FFSPEnv -from rl4co.envs.scheduling.smtwtp import SMTWTPEnv +from rl4co.envs.scheduling.ffsp.env import FFSPEnv +from rl4co.envs.scheduling.smtwtp.env import SMTWTPEnv diff --git a/rl4co/envs/scheduling/ffsp.py b/rl4co/envs/scheduling/ffsp/env.py similarity index 80% rename from rl4co/envs/scheduling/ffsp.py rename to rl4co/envs/scheduling/ffsp/env.py index 7c614064..f3037e18 100644 --- a/rl4co/envs/scheduling/ffsp.py +++ b/rl4co/envs/scheduling/ffsp/env.py @@ -13,6 +13,9 @@ UnboundedDiscreteTensorSpec, ) +from .generator import FFSPGenerator +from .render import render + from rl4co.envs.common.base import RL4COEnvBase @@ -20,40 +23,55 @@ class FFSPEnv(RL4COEnvBase): """Flexible Flow Shop Problem (FFSP) environment. The goal is to schedule a set of jobs on a set of machines such that the makespan is minimized. + Observations: + - time index + - sub time index + - batch index + - machine index + - schedule + - machine wait step + - job location + - job wait step + - job duration + + Constraints: + - each job has to be processed on each machine in a specific order + - the machine has to be available to process the job + - the job has to be available to be processed + + Finish Condition: + - all jobs are scheduled + + Reward: + - (minus) the makespan of the schedule + Args: - num_stage: number of stages - num_machine: number of machines in each stage - num_job: number of jobs - min_time: minimum processing time of a job - max_time: maximum processing time of a job - batch_size: batch size of the problem - - Note: - - [IMPORTANT] This version of ffsp requires the number of machines in each stage to be the same + generator: FFSPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "ffsp" def __init__( self, - num_stage: int, - num_machine: int, - num_job: int, - min_time: int = 2, - max_time: int = 10, - flatten_stages: bool = True, + generator: FFSPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_stage = num_stage - self.num_machine = num_machine - self.num_machine_total = num_stage * num_machine - self.num_job = num_job - self.min_time = min_time - self.max_time = max_time - self.flatten_stages = flatten_stages + if generator is None: + generator = FFSPGenerator(**generator_params) + self.generator = generator + + self.num_stage = generator.num_stage + self.num_machine = generator.num_machine + self.num_job = generator.num_job + self.num_machine_total = generator.num_machine_total self.tables = None self.step_cnt = None + self.flatten_stages = generator.flatten_stages + + self._make_spec(generator) def get_num_starts(self, td): return factorial(self.num_machine) @@ -255,51 +273,46 @@ def _reset( - job_wait_step [batch_size, num_job+1] - job_duration [batch_size, num_job+1, num_machine * num_stage] """ - if batch_size is None: - batch_size = self.batch_size if td is None else td.batch_size - - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) + device = td.device self.step_cnt = 0 - self.to(td.device) self.tables = IndexTables(self) # reset tables to undo the augmentation # self.tables._reset(device=self.device) self.tables.set_bs(batch_size[0]) # Init index record tensor - time_idx = torch.zeros(size=(*batch_size,), dtype=torch.long, device=self.device) + time_idx = torch.zeros(size=(*batch_size,), dtype=torch.long, device=device) sub_time_idx = torch.zeros( - size=(*batch_size,), dtype=torch.long, device=self.device + size=(*batch_size,), dtype=torch.long, device=device ) # Scheduling status information schedule = torch.full( size=(*batch_size, self.num_machine_total, self.num_job + 1), dtype=torch.long, - device=self.device, + device=device, fill_value=-999999, ) machine_wait_step = torch.zeros( size=(*batch_size, self.num_machine_total), dtype=torch.long, - device=self.device, + device=device, ) job_location = torch.zeros( size=(*batch_size, self.num_job + 1), dtype=torch.long, - device=self.device, + device=device, ) job_wait_step = torch.zeros( size=(*batch_size, self.num_job + 1), dtype=torch.long, - device=self.device, + device=device, ) job_duration = torch.empty( size=(*batch_size, self.num_job + 1, self.num_machine * self.num_stage), dtype=torch.long, - device=self.device, + device=device, ) if self.flatten_stages: assert ( @@ -322,18 +335,18 @@ def _reset( reward = torch.full( size=(*batch_size,), dtype=torch.float32, - device=self.device, + device=device, fill_value=float("-inf"), ) done = torch.full( size=(*batch_size,), dtype=torch.bool, - device=self.device, + device=device, fill_value=False, ) action_mask = torch.ones( - size=(*batch_size, self.num_job + 1), dtype=bool, device=self.device + size=(*batch_size, self.num_job + 1), dtype=bool, device=device ) action_mask[..., -1] = 0 @@ -364,7 +377,7 @@ def _reset( batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict): + def _make_spec(self, generator: FFSPGenerator): self.observation_spec = CompositeSpec( time_idx=UnboundedDiscreteTensorSpec( shape=(1,), @@ -383,23 +396,23 @@ def _make_spec(self, td_params: TensorDict): dtype=torch.int64, ), schedule=UnboundedDiscreteTensorSpec( - shape=(self.num_machine_total, self.num_job + 1), + shape=(generator.num_machine_total, generator.num_job + 1), dtype=torch.int64, ), machine_wait_step=UnboundedDiscreteTensorSpec( - shape=(self.num_machine_total), + shape=(generator.num_machine_total), dtype=torch.int64, ), job_location=UnboundedDiscreteTensorSpec( - shape=(self.num_job + 1), + shape=(generator.num_job + 1), dtype=torch.int64, ), job_wait_step=UnboundedDiscreteTensorSpec( - shape=(self.num_job + 1), + shape=(generator.num_job + 1), dtype=torch.int64, ), job_duration=UnboundedDiscreteTensorSpec( - shape=(self.num_job + 1, self.num_machine * self.num_stage), + shape=(generator.num_job + 1, generator.num_machine * generator.num_stage), dtype=torch.int64, ), shape=(), @@ -408,85 +421,14 @@ def _make_spec(self, td_params: TensorDict): shape=(1,), dtype=torch.int64, low=0, - high=self.num_loc, + high=generator.num_machine_total, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - def get_reward(self, td, actions) -> TensorDict: + def _get_reward(self, td, actions) -> TensorDict: return td["reward"] - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Init observation: running time of each job on each machine - run_time = torch.randint( - low=self.min_time, - high=self.max_time, - size=(*batch_size, self.num_job, self.num_machine, self.num_stage), - ).to(self.device) - - if self.flatten_stages: - run_time = ( - run_time.transpose(-2, -1) - .contiguous() - .view(*batch_size, self.num_job, self.num_machine_total) - ) - - return TensorDict( - { - "run_time": run_time, - }, - batch_size=batch_size, - ) - - def render(self, td: TensorDict, idx: int): - import matplotlib.patches as patches - import matplotlib.pyplot as plt - - job_durations = td["job_duration"][idx, :, :] - # shape: (job, machine) - schedule = td["schedule"][idx, :, :] - # shape: (machine, job) - - total_machine_cnt = self.num_machine_total - makespan = -td["reward"][idx].item() - - # Create figure and axes - fig, ax = plt.subplots(figsize=(makespan / 3, 5)) - cmap = self._get_cmap(self.num_job) - - plt.xlim(0, makespan) - plt.ylim(0, total_machine_cnt) - ax.invert_yaxis() - - plt.plot([0, makespan], [4, 4], "black") - plt.plot([0, makespan], [8, 8], "black") - - for machine_idx in range(total_machine_cnt): - duration = job_durations[:, machine_idx] - # shape: (job) - machine_schedule = schedule[machine_idx, :] - # shape: (job) - - for job_idx in range(self.num_job): - job_length = duration[job_idx].item() - job_start_time = machine_schedule[job_idx].item() - if job_start_time >= 0: - # Create a Rectangle patch - rect = patches.Rectangle( - (job_start_time, machine_idx), - job_length, - 1, - facecolor=cmap(job_idx), - ) - ax.add_patch(rect) - - ax.grid() - ax.set_axisbelow(True) - plt.show() - def _get_cmap(self, color_cnt): from random import shuffle diff --git a/rl4co/envs/scheduling/ffsp/generator.py b/rl4co/envs/scheduling/ffsp/generator.py new file mode 100644 index 00000000..8ec3ce71 --- /dev/null +++ b/rl4co/envs/scheduling/ffsp/generator.py @@ -0,0 +1,79 @@ +import os +import zipfile +from typing import Union, Callable + +import torch +import numpy as np + +from robust_downloader import download +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class FFSPGenerator(Generator): + """Data generator for the Flow Shop Scheduling Problem (FFSP). + + Args: + num_stage: number of stages + num_machine: number of machines + num_job: number of jobs + min_time: minimum running time of each job on each machine + max_time: maximum running time of each job on each machine + flatten_stages: whether to flatten the stages + + Returns: + A TensorDict with the following key: + run_time [batch_size, num_job, num_machine, num_stage]: running time of each job on each machine + + Note: + - [IMPORTANT] This version of ffsp requires the number of machines in each stage to be the same + """ + def __init__( + self, + num_stage: int = 2, + num_machine: int = 3, + num_job: int = 4, + min_time: int = 2, + max_time: int = 10, + flatten_stages: bool = True, + **unused_kwargs + ): + self.num_stage = num_stage + self.num_machine = num_machine + self.num_machine_total = num_machine * num_stage + self.num_job = num_job + self.min_time = min_time + self.max_time = max_time + self.flatten_stages = flatten_stages + + # FFSP environment doen't have any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + def _generate(self, batch_size) -> TensorDict: + # Init observation: running time of each job on each machine + run_time = torch.randint( + low=self.min_time, + high=self.max_time, + size=(*batch_size, self.num_job, self.num_machine, self.num_stage), + ) + + if self.flatten_stages: + run_time = ( + run_time.transpose(-2, -1) + .contiguous() + .view(*batch_size, self.num_job, self.num_machine_total) + ) + + return TensorDict( + { + "run_time": run_time, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/scheduling/ffsp/render.py b/rl4co/envs/scheduling/ffsp/render.py new file mode 100644 index 00000000..992f3ad4 --- /dev/null +++ b/rl4co/envs/scheduling/ffsp/render.py @@ -0,0 +1,72 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps +from tensordict.tensordict import TensorDict + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td: TensorDict, idx: int): + import matplotlib.patches as patches + import matplotlib.pyplot as plt + + # TODO: fix this render function parameters + num_machine_total = td["num_machine_total"][idx].item() + num_job = td["num_job"][idx].item() + + job_durations = td["job_duration"][idx, :, :] + # shape: (job, machine) + schedule = td["schedule"][idx, :, :] + # shape: (machine, job) + + total_machine_cnt = num_machine_total + makespan = -td["reward"][idx].item() + + # Create figure and axes + fig, ax = plt.subplots(figsize=(makespan / 3, 5)) + cmap = _get_cmap(num_job) + + plt.xlim(0, makespan) + plt.ylim(0, total_machine_cnt) + ax.invert_yaxis() + + plt.plot([0, makespan], [4, 4], "black") + plt.plot([0, makespan], [8, 8], "black") + + for machine_idx in range(total_machine_cnt): + duration = job_durations[:, machine_idx] + # shape: (job) + machine_schedule = schedule[machine_idx, :] + # shape: (job) + + for job_idx in range(num_job): + job_length = duration[job_idx].item() + job_start_time = machine_schedule[job_idx].item() + if job_start_time >= 0: + # Create a Rectangle patch + rect = patches.Rectangle( + (job_start_time, machine_idx), + job_length, + 1, + facecolor=cmap(job_idx), + ) + ax.add_patch(rect) + + ax.grid() + ax.set_axisbelow(True) + plt.show() + +def _get_cmap(color_cnt): + from random import shuffle + + from matplotlib.colors import CSS4_COLORS, ListedColormap + + color_list = list(CSS4_COLORS.keys()) + shuffle(color_list) + cmap = ListedColormap(color_list, N=color_cnt) + return cmap diff --git a/rl4co/envs/scheduling/smtwtp.py b/rl4co/envs/scheduling/smtwtp/env.py similarity index 56% rename from rl4co/envs/scheduling/smtwtp.py rename to rl4co/envs/scheduling/smtwtp/env.py index 25f461af..8bc311ea 100644 --- a/rl4co/envs/scheduling/smtwtp.py +++ b/rl4co/envs/scheduling/smtwtp/env.py @@ -13,6 +13,9 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.utils.pylogger import get_pylogger +from .generator import SMTWTPGenerator +from .render import render + log = get_pylogger(__name__) @@ -25,42 +28,48 @@ class SMTWTPEnv(RL4COEnvBase): At each step, the agent chooses a job to process. The reward is 0 unless the agent processes all the jobs. In that case, the reward is (-)objective value of the processing order: maximizing the reward is equivalent to minimizing the objective. + Observation: + - job_due_time: the due time of each job + - job_weight: the weight of each job + - job_process_time: the process time of each job + - current_node: the current node + - action_mask: a mask of available actions + - current_time: the current time + + Constants: + - num_job: number of jobs + - min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span) + - max_time_span: upper bound of jobs' due time. By default, it will be set to num_job / 2 + - min_job_weight: lower bound of jobs' weights. By default, jobs' weights are uniformly sampled from (min_job_weight, max_job_weight) + - max_job_weight: upper bound of jobs' weights + - min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time) + - max_process_time: upper bound of jobs' process time + + Finishing condition: + - All jobs are processed + + Reward: + - The reward is 0 unless the agent processes all the jobs. + - In that case, the reward is (-)objective value of the processing order: maximizing the reward is equivalent to minimizing the objective. + Args: - num_job: number of jobs - min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span) - max_time_span: upper bound of jobs' due time. By default, it will be set to num_job / 2 - min_job_weight: lower bound of jobs' weights. By default, jobs' weights are uniformly sampled from (min_job_weight, max_job_weight) - max_job_weight: upper bound of jobs' weights - min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time) - max_process_time: upper bound of jobs' process time - td_params: parameters of the environment - seed: seed for the environment - device: device to use. Generally, no need to set as tensors are updated on the fly + generator: FFSPGenerator instance as the data generator + generator_params: parameters for the generator """ name = "smtwtp" def __init__( self, - num_job: int = 10, - min_time_span: float = 0, - max_time_span: float = None, # will be set to num_job/2 by default. In DeepACO, it is set to num_job, which would be too simple - min_job_weight: float = 0, - max_job_weight: float = 1, - min_process_time: float = 0, - max_process_time: float = 1, - td_params: TensorDict = None, + generator: SMTWTPGenerator = None, + generator_params: dict = {}, **kwargs, ): super().__init__(**kwargs) - self.num_job = num_job - self.min_time_span = min_time_span - self.max_time_span = num_job / 2 if max_time_span is None else max_time_span - self.min_job_weight = min_job_weight - self.max_job_weight = max_job_weight - self.min_process_time = min_process_time - self.max_process_time = max_process_time - self._make_spec(td_params) + if generator is None: + generator = SMTWTPGenerator(**generator_params) + self.generator = generator + self._make_spec(self.generator) @staticmethod def _step(td: TensorDict) -> TensorDict: @@ -95,15 +104,7 @@ def _step(td: TensorDict) -> TensorDict: return td def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: - # Initialization - if batch_size is None: - batch_size = self.batch_size if td is None else td["job_due_time"].shape[:-1] - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - device = td["job_due_time"].device if td is not None else self.device - self.to(device) - - td = self.generate_data(batch_size) if td is None else td + device = td.device init_job_due_time = td["job_due_time"] init_job_process_time = td["job_process_time"] @@ -113,7 +114,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict current_job = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) current_time = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) available = torch.ones( - (*batch_size, self.num_job + 1), dtype=torch.bool, device=device + (*batch_size, self.generator.num_job + 1), dtype=torch.bool, device=device ) available[:, 0] = 0 # mask the starting dummy node @@ -129,24 +130,24 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict batch_size=batch_size, ) - def _make_spec(self, td_params: TensorDict = None): + def _make_spec(self, generator: SMTWTPGenerator) -> None: self.observation_spec = CompositeSpec( job_due_time=BoundedTensorSpec( - low=self.min_time_span, - high=self.max_time_span, - shape=(self.num_job + 1,), + low=generator.min_time_span, + high=generator.max_time_span, + shape=(generator.num_job + 1,), dtype=torch.float32, ), job_weight=BoundedTensorSpec( - low=self.min_job_weight, - high=self.max_job_weight, - shape=(self.num_job + 1,), + low=generator.min_job_weight, + high=generator.max_job_weight, + shape=(generator.num_job + 1,), dtype=torch.float32, ), job_process_time=BoundedTensorSpec( - low=self.min_process_time, - high=self.max_process_time, - shape=(self.num_job + 1,), + low=generator.min_process_time, + high=generator.max_process_time, + shape=(generator.num_job + 1,), dtype=torch.float32, ), current_node=UnboundedDiscreteTensorSpec( @@ -154,7 +155,7 @@ def _make_spec(self, td_params: TensorDict = None): dtype=torch.int64, ), action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_job + 1,), + shape=(generator.num_job + 1,), dtype=torch.bool, ), current_time=UnboundedContinuousTensorSpec( @@ -167,12 +168,12 @@ def _make_spec(self, td_params: TensorDict = None): shape=(1,), dtype=torch.int64, low=0, - high=self.num_job + 1, + high=generator.num_job + 1, ) self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - def get_reward(self, td, actions) -> TensorDict: + def _get_reward(self, td, actions) -> TensorDict: job_due_time = td["job_due_time"] job_weight = td["job_weight"] job_process_time = td["job_process_time"] @@ -193,39 +194,10 @@ def get_reward(self, td, actions) -> TensorDict: return -job_weighted_tardiness.sum(-1) - def generate_data(self, batch_size) -> TensorDict: - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - # Sampling according to Ye et al. (2023) - job_due_time = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_time_span, self.max_time_span) - .to(self.device) - ) - job_weight = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_job_weight, self.max_job_weight) - .to(self.device) - ) - job_process_time = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_process_time, self.max_process_time) - .to(self.device) - ) - - # Rollouts begin at dummy node 0, whose features are set to 0 - job_due_time[:, 0] = 0 - job_weight[:, 0] = 0 - job_process_time[:, 0] = 0 - - return TensorDict( - { - "job_due_time": job_due_time, - "job_weight": job_weight, - "job_process_time": job_process_time, - }, - batch_size=batch_size, - ) + def check_solution_validity(self, td, actions): + log.warning("Checking solution validity is not implemented for SMTWTP") + pass @staticmethod def render(td, actions=None, ax=None): - raise NotImplementedError("TODO: render is not implemented yet") + raise render(td, actions, ax) diff --git a/rl4co/envs/scheduling/smtwtp/generator.py b/rl4co/envs/scheduling/smtwtp/generator.py new file mode 100644 index 00000000..39701478 --- /dev/null +++ b/rl4co/envs/scheduling/smtwtp/generator.py @@ -0,0 +1,88 @@ +import os +import zipfile +from typing import Union, Callable + +import torch +import numpy as np + +from robust_downloader import download +from torch.distributions import Uniform +from tensordict.tensordict import TensorDict + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.utils.pylogger import get_pylogger +from rl4co.envs.common.utils import get_sampler, Generator + +log = get_pylogger(__name__) + + +class SMTWTPGenerator(Generator): + """Data generator for the Single Machine Total Weighted Tardiness Problem (SMTWTP) environment + + Args: + num_job: number of jobs + min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span) + max_time_span: upper bound of jobs' due time. By default, it will be set to num_job / 2 + min_job_weight: lower bound of jobs' weights. By default, jobs' weights are uniformly sampled from (min_job_weight, max_job_weight) + max_job_weight: upper bound of jobs' weights + min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time) + max_process_time: upper bound of jobs' process time + + Returns: + A TensorDict with the following key: + job_due_time [batch_size, num_job + 1]: the due time of each job + job_weight [batch_size, num_job + 1]: the weight of each job + job_process_time [batch_size, num_job + 1]: the process time of each job + """ + def __init__( + self, + num_job: int = 10, + min_time_span: float = 0, + max_time_span: float = None, # will be set to num_job / 2 by default. In DeepACO, it is set to num_job, which would be too simple + min_job_weight: float = 0, + max_job_weight: float = 1, + min_process_time: float = 0, + max_process_time: float = 1, + **unused_kwargs + ): + self.num_job = num_job + self.min_time_span = min_time_span + self.max_time_span = num_job / 2 if max_time_span is None else max_time_span + self.min_job_weight = min_job_weight + self.max_job_weight = max_job_weight + self.min_process_time = min_process_time + self.max_process_time = max_process_time + + # SMTWTP environment doen't have any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + def _generate(self, batch_size) -> TensorDict: + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + # Sampling according to Ye et al. (2023) + job_due_time = ( + torch.FloatTensor(*batch_size, self.num_job + 1) + .uniform_(self.min_time_span, self.max_time_span) + ) + job_weight = ( + torch.FloatTensor(*batch_size, self.num_job + 1) + .uniform_(self.min_job_weight, self.max_job_weight) + ) + job_process_time = ( + torch.FloatTensor(*batch_size, self.num_job + 1) + .uniform_(self.min_process_time, self.max_process_time) + ) + + # Rollouts begin at dummy node 0, whose features are set to 0 + job_due_time[:, 0] = 0 + job_weight[:, 0] = 0 + job_process_time[:, 0] = 0 + + return TensorDict( + { + "job_due_time": job_due_time, + "job_weight": job_weight, + "job_process_time": job_process_time, + }, + batch_size=batch_size, + ) diff --git a/rl4co/envs/scheduling/smtwtp/render.py b/rl4co/envs/scheduling/smtwtp/render.py new file mode 100644 index 00000000..9f8eedf0 --- /dev/null +++ b/rl4co/envs/scheduling/smtwtp/render.py @@ -0,0 +1,15 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib import cm, colormaps +from tensordict.tensordict import TensorDict + +from rl4co.utils.ops import gather_by_index +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td: TensorDict, actions=None, ax=None): + raise NotImplementedError diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py index 0894cd69..24212310 100644 --- a/rl4co/models/zoo/matnet/policy.py +++ b/rl4co/models/zoo/matnet/policy.py @@ -6,7 +6,7 @@ from tensordict import TensorDict -from rl4co.envs.scheduling.ffsp import FFSPEnv +from rl4co.envs.scheduling.ffsp.env import FFSPEnv from rl4co.models.common.constructive.autoregressive import AutoregressivePolicy from rl4co.models.zoo.matnet.decoder import ( MatNetDecoder, diff --git a/tests/test_envs.py b/tests/test_envs.py index a3101054..e892718f 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -61,13 +61,8 @@ def test_eda(env_cls, batch_size=2, max_decaps=5): @pytest.mark.parametrize("env_cls", [FFSPEnv]) def test_scheduling(env_cls, batch_size=2): - env = env_cls( - num_stage=2, - num_machine=3, - num_job=4, - batch_size=[batch_size], - ) - td = env.reset() + env = env_cls() + td = env.reset(batch_size=[batch_size]) td["action"] = torch.tensor([1, 1]) td = env._step(td)