Skip to content

Don't artificially limit the rendering axes for CVRP environment #264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rl4co/envs/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@

def render(self, *args, **kwargs):
"""Render the environment"""
raise NotImplementedError
raise NotImplementedError(

Check warning on line 285 in rl4co/envs/common/base.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/common/base.py#L285

Added line #L285 was not covered by tests
f"Render is not implemented for {self.name} environment. Please implement the `render` method in the subclass."
)

@staticmethod
def load_data(fpath, batch_size=[]):
Expand Down
14 changes: 5 additions & 9 deletions rl4co/envs/eda/dpp/render.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm, colormaps

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand All @@ -15,7 +12,6 @@ def render(self, decaps, probe, action_mask, ax=None, legend=True):
Plot a grid of 1x1 squares representing the environment.
The keepout regions are the action_mask - decaps - probe
"""
import matplotlib.pyplot as plt

settings = {
0: {"color": "white", "label": "available"},
Expand Down Expand Up @@ -58,9 +54,7 @@ def render(self, decaps, probe, action_mask, ax=None, legend=True):
ax.add_patch(plt.Rectangle((x, y), 1, 1, color=color, linestyle="-"))

# Add grid with 1x1 squares
ax.grid(
which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5
)
ax.grid(which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5)
# set 10 ticks
ax.set_xticks(np.arange(0, xdim, 1))
ax.set_yticks(np.arange(0, ydim, 1))
Expand All @@ -82,3 +76,5 @@ def render(self, decaps, probe, action_mask, ax=None, legend=True):
loc="upper center",
bbox_to_anchor=(0.5, 1.1),
)

return ax
19 changes: 6 additions & 13 deletions rl4co/envs/eda/mdpp/render.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm, colormaps

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand All @@ -15,8 +12,6 @@ def render(self, td, actions=None, ax=None, legend=True, settings=None):
The keepout regions are the action_mask - decaps - probe
"""

import matplotlib.pyplot as plt

from matplotlib.lines import Line2D
from matplotlib.patches import Annulus, Rectangle, RegularPolygon

Expand Down Expand Up @@ -87,9 +82,7 @@ def draw_keepout(ax, x, y, color="black"):
# Backgrund rectangle: same as color but with alpha=0.5
ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5))
ax.add_patch(
RegularPolygon(
(x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color
)
RegularPolygon((x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color)
)

size = self.size
Expand Down Expand Up @@ -132,9 +125,7 @@ def draw_keepout(ax, x, y, color="black"):
elif keepout[i, j] == 1:
draw_keepout(ax, x, y, color=settings["keepout"]["color"])

ax.grid(
which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5
)
ax.grid(which="major", axis="both", linestyle="-", color="k", linewidth=1, alpha=0.5)
# set 10 ticks
ax.set_xticks(np.arange(0, xdim, 1))
ax.set_yticks(np.arange(0, ydim, 1))
Expand All @@ -159,3 +150,5 @@ def draw_keepout(ax, x, y, color="black"):
loc="upper center",
bbox_to_anchor=(0.5, 1.1),
)

return ax
12 changes: 4 additions & 8 deletions rl4co/envs/routing/atsp/render.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger
Expand Down Expand Up @@ -41,10 +41,6 @@

# Add arrows between visited nodes as a quiver plot
dx, dy = np.diff(x), np.diff(y)
ax.quiver(
x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k"
)
ax.quiver(x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k")

Check warning on line 44 in rl4co/envs/routing/atsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/atsp/render.py#L44

Added line #L44 was not covered by tests

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 46 in rl4co/envs/routing/atsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/atsp/render.py#L46

Added line #L46 was not covered by tests
3 changes: 1 addition & 2 deletions rl4co/envs/routing/cvrp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,4 @@
annotation_clip=False,
)

ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 138 in rl4co/envs/routing/cvrp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/cvrp/render.py#L138

Added line #L138 was not covered by tests
2 changes: 1 addition & 1 deletion rl4co/envs/routing/cvrptw/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@

@staticmethod
def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
render(td, actions, ax)
return render(td, actions, ax)

Check warning on line 218 in rl4co/envs/routing/cvrptw/env.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/cvrptw/env.py#L218

Added line #L218 was not covered by tests

@staticmethod
def load_data(
Expand Down
8 changes: 3 additions & 5 deletions rl4co/envs/routing/cvrptw/render.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm, colormaps

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -129,5 +128,4 @@
annotation_clip=False,
)

ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 131 in rl4co/envs/routing/cvrptw/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/cvrptw/render.py#L131

Added line #L131 was not covered by tests
15 changes: 8 additions & 7 deletions rl4co/envs/routing/mdcpdp/render.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch

from tensordict.tensordict import TensorDict


def render(td: TensorDict, actions=None, ax=None):
import matplotlib.pyplot as plt

markersize = 8

td = td.detach().cpu()
Expand All @@ -28,10 +31,10 @@

# Plot the actions in order
last_depot = 0
for i in range(len(actions)-1):
if actions[i+1] < n_depots:
last_depot = actions[i+1]
if actions[i] < n_depots and actions[i+1] < n_depots:
for i in range(len(actions) - 1):
if actions[i + 1] < n_depots:
last_depot = actions[i + 1]
if actions[i] < n_depots and actions[i + 1] < n_depots:

Check warning on line 37 in rl4co/envs/routing/mdcpdp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/mdcpdp/render.py#L34-L37

Added lines #L34 - L37 were not covered by tests
continue
from_node = actions[i]
to_node = (
Expand Down Expand Up @@ -115,6 +118,4 @@
alpha=0.5,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 121 in rl4co/envs/routing/mdcpdp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/mdcpdp/render.py#L121

Added line #L121 was not covered by tests
14 changes: 3 additions & 11 deletions rl4co/envs/routing/mpdp/render.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm, colormaps

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand All @@ -13,11 +12,6 @@
def render(td, actions=None, ax=None):
# TODO: color switch with new agents; add pickup and delivery nodes as in `PDPEnv.render`

import matplotlib.pyplot as plt
import numpy as np

from matplotlib import cm, colormaps

num_routine = (actions == 0).sum().item() + 2
base = colormaps["nipy_spectral"]
color_list = base(np.linspace(0, 1, num_routine))
Expand Down Expand Up @@ -109,6 +103,4 @@ def render(td, actions=None, ax=None):
annotation_clip=False,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax
6 changes: 4 additions & 2 deletions rl4co/envs/routing/mtsp/render.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import colormaps

Expand Down Expand Up @@ -93,3 +93,5 @@
ax.set_title("mTSP")
ax.set_xlabel("x-coordinate")
ax.set_ylabel("y-coordinate")

return ax

Check warning on line 97 in rl4co/envs/routing/mtsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/mtsp/render.py#L97

Added line #L97 was not covered by tests
7 changes: 2 additions & 5 deletions rl4co/envs/routing/mtvrp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def render(
td: TensorDict, actions=None, ax=None, scale_xy: bool = True, vehicle_capacity=None
td: TensorDict, actions=None, ax=None, scale_xy: bool = False, vehicle_capacity=None
):
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -139,7 +139,4 @@ def render(
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)

# Remove the ticks
ax.set_xticks([])
ax.set_yticks([])
plt.show()
return ax
14 changes: 4 additions & 10 deletions rl4co/envs/routing/op/render.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch

from matplotlib import cm, colormaps

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -32,8 +29,7 @@
customers = td["locs"][1:, :]
prizes = td["prize"][1:]
normalized_prizes = (
200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes))
+ 10
200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) + 10
)

# Plot depot and customers with prize
Expand Down Expand Up @@ -81,6 +77,4 @@
width=0.0035,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 80 in rl4co/envs/routing/op/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/op/render.py#L80

Added line #L80 was not covered by tests
7 changes: 2 additions & 5 deletions rl4co/envs/routing/pctsp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
prizes = td["real_prize"][1:]
penalties = td["penalty"][1:]
normalized_prizes = (
200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes))
+ 10
200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) + 10
)
normalized_penalties = (
3
Expand Down Expand Up @@ -88,6 +87,4 @@
width=0.0035,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 90 in rl4co/envs/routing/pctsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/pctsp/render.py#L90

Added line #L90 was not covered by tests
6 changes: 3 additions & 3 deletions rl4co/envs/routing/pdp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@
label="Delivery" if i == 0 else None,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 78 in rl4co/envs/routing/pdp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/pdp/render.py#L78

Added line #L78 was not covered by tests


def render_improvement(td, current_soltuion, best_soltuion):
Expand Down Expand Up @@ -140,3 +138,5 @@
)
ax.set_title("Best Solution")
plt.tight_layout()

return ax

Check warning on line 142 in rl4co/envs/routing/pdp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/pdp/render.py#L142

Added line #L142 was not covered by tests
4 changes: 1 addition & 3 deletions rl4co/envs/routing/shpp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,4 @@
headwidth=8,
)

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 62 in rl4co/envs/routing/shpp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/shpp/render.py#L62

Added line #L62 was not covered by tests
1 change: 1 addition & 0 deletions rl4co/envs/routing/svrp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,4 @@ def render(td, actions=None, ax=None):
size=15,
annotation_clip=False,
)
return ax
6 changes: 3 additions & 3 deletions rl4co/envs/routing/tsp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@
dx, dy = np.diff(x), np.diff(y)
ax.quiver(x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k")

# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
return ax

Check warning on line 46 in rl4co/envs/routing/tsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/tsp/render.py#L46

Added line #L46 was not covered by tests


def render_improvement(td, current_soltuion, best_soltuion):
Expand Down Expand Up @@ -108,3 +106,5 @@
)
ax.set_title("Best Solution")
plt.tight_layout()

return ax

Check warning on line 110 in rl4co/envs/routing/tsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/routing/tsp/render.py#L110

Added line #L110 was not covered by tests
8 changes: 2 additions & 6 deletions rl4co/envs/scheduling/ffsp/render.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import torch
import numpy as np
import matplotlib.pyplot as plt

from matplotlib import cm, colormaps
from tensordict.tensordict import TensorDict

from rl4co.utils.ops import gather_by_index
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def render(td: TensorDict, idx: int):
import matplotlib.patches as patches
import matplotlib.pyplot as plt

# TODO: fix this render function parameters
num_machine_total = td["num_machine_total"][idx].item()
Expand Down Expand Up @@ -59,7 +54,8 @@ def render(td: TensorDict, idx: int):

ax.grid()
ax.set_axisbelow(True)
plt.show()
return ax


def _get_cmap(color_cnt):
from random import shuffle
Expand Down
3 changes: 1 addition & 2 deletions rl4co/envs/scheduling/fjsp/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,4 @@
)

plt.tight_layout()
# Show the Gantt chart
plt.show()
return ax

Check warning on line 71 in rl4co/envs/scheduling/fjsp/render.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/scheduling/fjsp/render.py#L71

Added line #L71 was not covered by tests
2 changes: 1 addition & 1 deletion rl4co/envs/scheduling/smtwtp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,4 @@

@staticmethod
def render(td, actions=None, ax=None):
raise render(td, actions, ax)
return render(td, actions, ax)

Check warning on line 198 in rl4co/envs/scheduling/smtwtp/env.py

View check run for this annotation

Codecov / codecov/patch

rl4co/envs/scheduling/smtwtp/env.py#L198

Added line #L198 was not covered by tests