Major environment refactoring (draft version) #166
Conversation
```python
if batch_size is None:
    batch_size = self.batch_size if td is None else td.batch_size
if td is None or td.is_empty():
    td = self.generator(batch_size=batch_size)
```
Couldn't we include the generator as a parameter in the base environment already and set it in `__init__()`?
It's just that we're calling it here (and further below), but in the base environment it doesn't actually exist.
Yes, the reason they are here is that they could get passed to the environment itself, as in TorchRL. This is the function signature for `EnvBase` in TorchRL:

```python
def __init__(
    self,
    *,
    device: DEVICE_TYPING = None,
    batch_size: Optional[torch.Size] = None,
    run_type_checks: bool = False,
    allow_done_after_reset: bool = False,
):
```
So I guess we should make the above explicit in `RL4COEnvBase` since ours is a child class!
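As a sketch of that suggestion, `RL4COEnvBase` could spell out the same keyword arguments and forward them to the parent. The `EnvBase` below is a simplified stand-in for `torchrl.envs.EnvBase` so the example is self-contained; in the real codebase the torchrl class would be subclassed:

```python
# Simplified stand-in for torchrl.envs.EnvBase (same keyword arguments as the
# signature quoted above); RL4CO would subclass the real torchrl class instead.
class EnvBase:
    def __init__(self, *, device=None, batch_size=None,
                 run_type_checks: bool = False,
                 allow_done_after_reset: bool = False):
        self.device = device
        self.batch_size = batch_size
        self.run_type_checks = run_type_checks
        self.allow_done_after_reset = allow_done_after_reset


class RL4COEnvBase(EnvBase):
    """Expose the parent's arguments explicitly instead of hiding them in **kwargs."""

    def __init__(self, *, device="cpu", batch_size=None,
                 run_type_checks: bool = False,
                 allow_done_after_reset: bool = False, **kwargs):
        super().__init__(device=device, batch_size=batch_size,
                         run_type_checks=run_type_checks,
                         allow_done_after_reset=allow_done_after_reset)


env = RL4COEnvBase(run_type_checks=True)
```

This way the parameters show up in the child's signature (and IDE autocompletion) rather than only in TorchRL's documentation.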
@ngastzepeda I think it makes sense that all environments now take `generator` and `generator_params` as inputs in the `__init__()` function. We could move them to `RL4COEnvBase()`.
Also, as @fedebotu said, it's better to expose other useful parameters from `torchrl.EnvBase` in our `RL4COEnvBase()`; that makes it easier for users to discover the provided APIs, not only from the documentation.
@ngastzepeda I've rethought adding `generator` and `generator_params` to the `RL4COEnvBase` class. I'd prefer to keep them in each environment for now, for two reasons:

1. We want users to be able to initialize an environment by simply calling `env = <EnvName>()`, e.g. `env = TSPEnv()`, without any parameters. In this case, each environment needs a default generator initialization with its respective generator class, e.g.

   rl4co/rl4co/envs/routing/cvrp/env.py
   Lines 59 to 61 in a9943c9

   ```python
   if generator is None:
       generator = CVRPGenerator(**generator_params)
   self.generator = generator
   ```

   It would be hard, or at least cumbersome, for users to understand if we implemented this part in the base class.
2. Putting the generator initialization in each environment gives users a hint for understanding the "generate data" -> "reset instance as a tensordict" -> "step rollout, ..." pipeline.

What do you think about this? 🤔
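The per-environment default described in point 1 can be sketched as follows; `Generator`, `TSPGenerator`, and `TSPEnv` here are simplified stand-ins, not the actual RL4CO classes:

```python
# Sketch of the "default generator per environment" pattern (stand-in classes).
class Generator:
    def __init__(self, **kwargs):
        self.params = kwargs


class TSPGenerator(Generator):
    def __init__(self, num_loc=20, **kwargs):
        super().__init__(num_loc=num_loc, **kwargs)


class TSPEnv:
    def __init__(self, generator=None, generator_params=None):
        generator_params = generator_params or {}
        # Each environment knows its own default generator class,
        # so `TSPEnv()` works with no arguments at all.
        if generator is None:
            generator = TSPGenerator(**generator_params)
        self.generator = generator


env = TSPEnv()                                    # no parameters needed
custom = TSPEnv(generator_params={"num_loc": 50})
```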
```python
    assert kwargs.get(val_name + "_mean", None) is not None, "mean is required for Normal distribution"
    assert kwargs.get(val_name + "_std", None) is not None, "std is required for Normal distribution"
    return Normal(mean=kwargs[val_name + "_mean"], std=kwargs[val_name + "_std"])
elif distribution == Exponential or distribution == "exponential":
    assert kwargs.get(val_name + "_rate", None) is not None, "rate is required for Exponential/Poisson distribution"
    return Exponential(rate=kwargs[val_name + "_rate"])
elif distribution == Poisson or distribution == "poisson":
    assert kwargs.get(val_name + "_rate", None) is not None, "rate is required for Exponential/Poisson distribution"
    return Poisson(rate=kwargs[val_name + "_rate"])
```
Do I understand correctly that we're assuming a specific format for these parameters (i.e., we expect parameters `<val_name>_mean`, `<val_name>_std`, etc.), and that it's not enough to simply pass e.g. `mean=5, std=2`?
Good question! Your understanding is correct. I've thought about this for some time. The thing is: this `get_sampler()` function will be called in the generator multiple times for different features, e.g. in the `OPGenerator()`:

```python
self.depot_sampler = get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs)
self.prize_sampler = get_sampler("prize", prize_distribution, min_prize, max_prize, **kwargs)
```

If the user wants to initialize the location with a `Normal` distribution and the prize with a `Poisson` distribution, 3 parameters are required:
- the mean of the location;
- the std of the location;
- the rate of the prize.

We have two options for handling these parameters in the `OPGenerator()`:

1. Add all of them explicitly to the `__init__()` inputs:

   ```python
   def __init__(self, min_loc, max_loc, mean_loc, std_loc, rate_loc, loc_distribution,
                min_prize, max_prize, mean_prize, std_prize, rate_prize, prize_distribution)
   ```

2. Support them via the `kwargs` in the `__init__()` inputs, i.e., the user follows the rule: if you want to use the `Normal` distribution for `<val_name>`, you have to pass extra parameters with exactly the names `<val_name>_mean` and `<val_name>_std`.

Both will work, but for clarity and flexibility I chose the second way. However, I understand this could be confusing for users, so we should clearly document the naming rule for these parameters.
If you have a better implementation, please tell me 🤔 I don't think the current implementation is necessarily optimal.
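To make the naming rule concrete, here is a self-contained sketch of option 2; this `get_sampler` is a simplified stand-in for the RL4CO one, using `torch.distributions` directly:

```python
import torch
from torch.distributions import Normal, Poisson, Uniform


def get_sampler(val_name: str, distribution: str, min_val=0.0, max_val=1.0, **kwargs):
    """Stand-in sampler factory: per-feature parameters are looked up in kwargs
    under the '<val_name>_mean' / '<val_name>_std' / '<val_name>_rate' naming rule."""
    if distribution == "uniform":
        return Uniform(low=min_val, high=max_val)
    if distribution == "normal":
        return Normal(loc=kwargs[f"{val_name}_mean"], scale=kwargs[f"{val_name}_std"])
    if distribution == "poisson":
        return Poisson(rate=kwargs[f"{val_name}_rate"])
    raise ValueError(f"Unknown distribution: {distribution}")


# The same **kwargs are forwarded to every call; each sampler picks out its keys.
kwargs = {"loc_mean": 0.5, "loc_std": 0.1, "prize_rate": 3.0}
loc_sampler = get_sampler("loc", "normal", **kwargs)
prize_sampler = get_sampler("prize", "poisson", **kwargs)
sample = loc_sampler.sample((2, 10))
```

The advantage over option 1 is that the generator's `__init__()` signature stays short while still allowing any per-feature parameter combination.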
I commented at first that instead of doing the same things in every single environment (specifically computing the `visited`, `current_node`, and `done` tensors, etc.), we should define that in a parent-class function so that the child classes can simply call the parent's method. Then I noticed that the base class is for all environments, not just routing. Maybe it would make sense, though, to have a base class for routing (and one for scheduling, etc.), even within the `base.py` file, which would inherit from `RL4COEnvBase` and could define the things that are the same for all routing envs, so we don't have to repeat them.
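A sketch of what such an intermediate class could look like (names and the helper are illustrative, not the actual RL4CO API; `RL4COEnvBase` is reduced to a stub so the example runs standalone):

```python
import torch


class RL4COEnvBase:  # stub standing in for rl4co.envs.common.base.RL4COEnvBase
    def __init__(self, **kwargs):
        pass


class RoutingEnvBase(RL4COEnvBase):
    """Hypothetical shared base for routing environments (TSP, CVRP, OP, ...)."""

    @staticmethod
    def update_visited(visited: torch.Tensor, current_node: torch.Tensor) -> torch.Tensor:
        # Mark the chosen node as visited -- the kind of step logic that is
        # currently repeated in every routing environment's _step().
        return visited.scatter(-1, current_node[..., None], 1)


class TSPEnv(RoutingEnvBase):
    pass


visited = torch.zeros(2, 5, dtype=torch.long)
current_node = torch.tensor([1, 3])
visited = TSPEnv.update_visited(visited, current_node)
```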
Apart from that I only left a few minor comments :)
```diff
@@ -1,6 +1,6 @@
 from rl4co.utils.pylogger import get_pylogger

-from .pctsp import PCTSPEnv
+from ..pctsp.env import PCTSPEnv
```
Since the only difference between this class and `PCTSP` seems to be that this one is stochastic, and there is no additional logic compared to `PCTSP`, why even have two separate environments and not just differentiate via the `stochastic` boolean parameter?
I agree, but conceptually they are a bit different, so it might be worth keeping the distinction. Technically, you could call PCTSP with the `stochastic` parameter on too.
rl4co/envs/routing/svrp/env.py
```python
visited = td["visited"].scatter(
    -1, current_node.expand_as(td["action_mask"]), 1
)
print(current_node)
```
Oops, I had forgotten to delete the print statement^^
```python
log = get_pylogger(__name__)


class SVRPEnv(RL4COEnvBase):
```
Now that we're doing the refactoring anyway, we might as well rename this environment to SkillVRP to avoid confusion with the Stochastic VRP :)
```python
log = get_pylogger(__name__)


class SVRPGenerator(Generator):
```
Rename to SkillVRPGenerator
```python
log = get_pylogger(__name__)


def render(td, actions=None, ax=None):
```
I have to admit, I've never actually rendered a Skill VRP problem, so no idea if this runs without problems
Great job!
Left some comments here and there. Additionally, as @hyeok9855 is doing, there should be an additional (optional) file called `local_search.py`.
```python
dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0

log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class))
if self.tmat_class:
```
Shouldn't this be inside of the sampler itself?
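For context, the TMAT-class post-processing this snippet leads into (enforcing the triangle inequality on a random distance matrix) is typically a Floyd-Warshall-style relaxation. A hedged sketch, not the actual rl4co code:

```python
import torch


def make_tmat(dms: torch.Tensor) -> torch.Tensor:
    """Enforce d[i, j] <= d[i, k] + d[k, j] for all k (triangle inequality).

    dms: [batch, n, n] non-negative distances with a zero diagonal.
    """
    dms = dms.clone()
    for k in range(dms.shape[-1]):  # Floyd-Warshall-style relaxation over pivots
        dms = torch.minimum(dms, dms[..., :, k].unsqueeze(-1) + dms[..., k, :].unsqueeze(-2))
    return dms


torch.manual_seed(0)
d = torch.rand(2, 6, 6)
d[..., torch.arange(6), torch.arange(6)] = 0  # zero diagonal, as in the snippet above
t = make_tmat(d)
```

Whether this belongs in the generator or inside the sampler itself is exactly the design question raised above.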
tests/test_policy.py
```diff
@@ -8,7 +8,8 @@
 # Main autoregressive policy: rollout over multiple envs since it is the base
 @pytest.mark.parametrize(
     "env_name",
-    ["tsp", "cvrp", "sdvrp", "mtsp", "op", "pctsp", "spctsp", "dpp", "mdpp", "smtwtp"],
+    # ["tsp", "cvrp", "sdvrp", "mtsp", "op", "pctsp", "spctsp", "dpp", "mdpp", "smtwtp"],
+    ["tsp", "cvrp", "sdvrp", "mtsp", "op", "pctsp", "spctsp"],
```
Why were tests from the above environments removed?
In the current refactoring version, I have only finished the routing environments part. Since we modified `RL4COEnvBase()`, this affects the EDA environments.
I will finish the refactoring for the EDA environments in the coming commits and put these checks back, don't worry.
Let's remember also to fix the shifts in the
Notice that we moved most of the above in here: #169 (without modification to environment logic or variables)! We will address the comments and merge soon~
There have been too many changes to track recently, and it seems that several features have already been added. I will be closing this for now and will come back with a fresh PR if needed!
Important
The merge of this pull request is postponed because it contains sensitive modifications to the environment logic, which may cause hidden bugs; we should be careful when updating them. Therefore, this full version of the environment refactoring will be kept as a draft. We opened another base-version refactoring pull request, #169, which only touches the environment structure and adds the generator without changing any logic, for a safe refactor in the current state. In the future, we will build on this draft's full version and refactor the environments further, step by step.
Description
Together with Major modeling refactoring #165, this PR is a major, long-overdue refactoring of the RL4CO environments codebase.
Motivation and Context
This refactoring is driven by the following motivations:
Changelog
Environment Structure Refactoring
The refactored structure for environments is as follows:
We have restructured the organization of the environment files for improved modularity and clarity. Each environment has its own directory, comprising three components:
- `env.py`: The core framework of the environment, managing functions such as `_reset()`, `_step()`, and others. For a comprehensive understanding, please refer to the documentation.
- `generator.py`: Replaces the previous `generate_data()` function; this module randomly initializes instances within the environment. The updated version now supports custom data distributions. See the following sections for more details.
- `render.py`: For visualization of the solution. Its separation from the main environment file enhances overall code readability.
Data Generator Support
Each environment generator will be based on the base `Generator()` class, with the following functions:
- `__init__()` records all the environment instance initialization parameters, for example, `num_loc`, `min_loc`, `max_loc`, etc. Thus, you will see that the `__init__()` function for the environment (e.g., `CVRPEnv.__init__(...)`) only takes `generator` and `generator_params` as input, so an environment can now be initialized as simply as `env = TSPEnv()`. Various samplers will be initialized here; we provide the `get_sampler()` function, which returns a `torch.distributions` class based on the input variables. By default, we support the `Uniform`, `Normal`, `Exponential`, and `Poisson` distributions for locations, and `center` and `corner` for depots. You can also pass your own distribution sampler. See the following sections for more details.
- `__call__()` is a middle wrapper; at the moment, it is used to regularize the `batch_size` format supported by TorchRL (i.e., a `list` format). Note that in this refactored version, we finalize the dimension of `batch_size` to be 1 for easier implementation and clearer understanding, since even multiple batch-size dimensions can easily be flattened into a single dimension.
- `__generate()` is the part you would implement for your own environment data generator.
New `get_sampler()` function
This implementation mainly refers to @ngastzepeda's code. In the current version, we support the following distributions:
- `center`: For depots. All depots will be initialized in the center of the space.
- `corner`: For depots. All depots will be initialized in the bottom-left corner of the space.
- `Uniform`: Takes `min_val` and `max_val` as input.
- `Exponential` and `Poisson`: Take `mean_val` and `std_val` as input.

You can also use your own `Callable` function as the sampler. This function will take the `batch_size: List[int]` as input and return the sampled `torch.Tensor`.
Modification for `RL4COEnvBase()`
We moved the checks for `batch_size` and `device` from every environment to the base class for clarity, as shown in rl4co/rl4co/envs/common/base.py, Lines 130 to 138 in b70566b.
We added a new `_get_reward()` function alongside the original `get_reward()` function and moved `check_solution_validity()` from every environment to the base class for clarity, as shown in rl4co/rl4co/envs/common/base.py, Lines 175 to 187 in b70566b.
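The `get_reward()` / `_get_reward()` split described above can be sketched as follows (simplified stand-ins; the actual implementation lives in `rl4co/envs/common/base.py`, and the `check_solution` toggle is an assumption for illustration):

```python
class RL4COEnvBase:
    """Stub: the public get_reward() runs the shared validity check once,
    while subclasses only implement the private _get_reward()."""

    check_solution = True  # assumed toggle; the real class may differ

    def get_reward(self, td, actions):
        if self.check_solution:
            self.check_solution_validity(td, actions)
        return self._get_reward(td, actions)

    def _get_reward(self, td, actions):
        raise NotImplementedError

    def check_solution_validity(self, td, actions):
        raise NotImplementedError


class DummyEnv(RL4COEnvBase):
    def _get_reward(self, td, actions):
        return -sum(actions)  # reward = -cost

    def check_solution_validity(self, td, actions):
        assert len(set(actions)) == len(actions), "each node must appear once"


reward = DummyEnv().get_reward(None, [1, 2, 3])  # -> -6
```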
Standardization
We standardize the contents of `env.py` with the following functions:
The order is considered natural and easy to follow, and we expect all environments to follow the same order for easier reference and maintenance. In more detail, we have the following standardization:
- Renamed `available` to `visited` for a more intuitive understanding. In the `step()` and `get_action_mask()` calculation, `visited` records which nodes have been visited, and the `action_mask` is derived from it together with the environment constraints (e.g., capacity, time window, etc.). Separating these two variables makes the calculation logic clearer.
- Changed the `_step()` function to a non-static method, following the TorchRL style.
- Standardized the `get_action_mask()` calculation logic, which generally contains three parts: (a) initialize the `action_mask` based on `visited`; (b) update the cities' `action_mask` based on the state; (c) finally, update the depot's `action_mask`. Based on experience, this logic causes fewer conflicts and less mess.
- Features such as `i`, `capacity`, `used_capacity`, etc., are initialized with the size `[*batch_size, 1]` instead of `[*batch_size]`. The reason is that in many masking operations, we need to do logical calculations between these 1-D features and 2-D features, e.g., capacity with demand. This also stays consistent with the TorchRL implementation.
- Moved instance-related parameters (`num_loc`, `min_loc`, `max_loc`) to the generator for clarity.
- Introduced a `cost` variable in the `get_reward()` function for an intuitive understanding; in this case, the return (reward) is `-cost`.
Other Fixes
- Renamed `vehicle_capacity` → `capacity` and `capacity` → `unnorm_capacity` to clarify.
- The `demand` variable will now also contain the depot. For example, in the previous `CVRPEnv()`, given `num_loc=50`, `td["locs"]` has the size `[batch_size, 51, 2]` (with the depot), while `td["demand"]` has the size `[batch_size, 50]`. This causes index shifting in the `get_action_mask()` function, which requires a few padding operations.
- Numerical robustness (`0` → `1e-5`): for example, in SDVRP, `done = ~(demand > 0).any(-1)` → `done = ~(demand > 1e-5).any(-1)` for better robustness against edge cases.
- For tables keyed by `num_loc`, e.g., the CVRP `CAPACITIES`, if the given `num_loc` is not in the table, we find the closest `num_loc` as a replacement and raise a warning, to increase running robustness.
- `get_reward()`.
Notes
- `num_depot`, `num_agents`: these values are initialized by `torch.randint()`.
- `0` is appended to the start and end.
Here is the summary of the refactoring status for each environment:
- `env.py`, `generator.py`, `render.py`: fixed the `__init__()` and `_reset()` functions; `check_solution_validity()` function; the `_step()` and `get_action_mask()` functions are cleaned up with the standard pipeline.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Thanks, and need your help
Thanks for @ngastzepeda's base code for this refactoring!
If you have time, welcome to provide your ideas/feedback on this PR.
CC: @Furffico @henry-yeh @bokveizen @LTluttmann
There is quite a bit of remaining work for this PR, and I will actively update it here.