-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
512 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
11 changes: 11 additions & 0 deletions
11
examples/images/reinforcementlearning/configs/canonicalization/group_equivariant.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
canonicalization_type: group_equivariant | ||
network_type: e2cnn # Options for canonization method 1) e2cnn 2) custom 3) none | ||
network_hyperparams: | ||
kernel_size: 5 # Kernel size for the canonization network | ||
out_channels: 32 # Number of output channels for the canonization network | ||
num_layers: 3 # Number of layers in the canonization network | ||
group_type: rotation # Type of group for the canonization network | ||
num_rotations: 4 # Number of rotations for the canonization network | ||
beta: 1.0 # Beta parameter for the canonization network | ||
input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization | ||
resize_shape: 64 # Resize shape for the input |
1 change: 1 addition & 0 deletions
1
examples/images/reinforcementlearning/configs/canonicalization/identity.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
canonicalization_type: identity |
23 changes: 23 additions & 0 deletions
23
examples/images/reinforcementlearning/configs/default.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
canonicalization_type: group_equivariant # will be set later in training script | ||
device: cuda | ||
|
||
# metadata specialised for each experiment | ||
core: | ||
version: 0.0.1 | ||
tags: | ||
- ${now:%Y-%m-%d} | ||
|
||
hydra: | ||
run: | ||
dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/ | ||
|
||
sweep: | ||
dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/ | ||
subdir: ${hydra.job.num}_${hydra.job.id} | ||
|
||
defaults: | ||
- _self_ | ||
- env: default | ||
- experiment: default | ||
- canonicalization: identity | ||
- wandb: default |
12 changes: 12 additions & 0 deletions
12
examples/images/reinforcementlearning/configs/experiment/default.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
learning_rate: 0.001 | ||
batch_size: 128 | ||
gamma: 0.999 | ||
eps_start: 0.9 | ||
eps_end: 0.01 | ||
eps_decay: 3000 | ||
target_update: 10 | ||
replay_memory_size: 100000 | ||
end_score: 200 | ||
training_stop: 142 | ||
num_episodes: 50000 | ||
last_episodes_num: 20 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import random | ||
|
||
class DQN(nn.Module): | ||
def __init__(self, input_shape, num_actions, dueling_DQN=False): | ||
super(DQN, self).__init__() | ||
|
||
self.input_shape = input_shape | ||
self.num_actions = num_actions | ||
self.dueling_DQN = dueling_DQN | ||
|
||
self.features = nn.Sequential( | ||
nn.Conv2d(input_shape[0], 32, kernel_size=5, stride=2), | ||
nn.BatchNorm2d(32), | ||
nn.ReLU(), | ||
nn.Conv2d(32, 64, kernel_size=5, stride=2), | ||
nn.BatchNorm2d(64), | ||
nn.ReLU(), | ||
nn.Conv2d(64, 64, kernel_size=5, stride=2), | ||
nn.BatchNorm2d(64), | ||
nn.ReLU() | ||
) | ||
|
||
feature_size = self._get_feature_size() | ||
|
||
if self.dueling_DQN: | ||
self.advantage = nn.Sequential( | ||
nn.Linear(feature_size, 512), | ||
nn.BatchNorm1d(512), | ||
nn.ReLU(), | ||
nn.Linear(512, self.num_actions) | ||
) | ||
self.value = nn.Sequential( | ||
nn.Linear(feature_size, 512), | ||
nn.BatchNorm1d(512), | ||
nn.ReLU(), | ||
nn.Linear(512, 1) | ||
) | ||
else: | ||
self.action_value = nn.Sequential( | ||
nn.Linear(feature_size, 512), | ||
nn.BatchNorm1d(512), | ||
nn.ReLU(), | ||
nn.Linear(512, self.num_actions) | ||
) | ||
|
||
def forward(self, x): | ||
x = x.float() / 255 # Normalize the input | ||
x = self.features(x) | ||
x = x.view(x.size(0), -1) | ||
|
||
if self.dueling_DQN: | ||
advantage = self.advantage(x) | ||
value = self.value(x) | ||
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True)) | ||
else: | ||
q_values = self.action_value(x) | ||
|
||
return q_values | ||
|
||
def _get_feature_size(self): | ||
self.features.eval() | ||
with torch.no_grad(): | ||
return self.features(torch.zeros(1, *self.input_shape)).view(1, -1).size(1) | ||
|
Empty file.
94 changes: 94 additions & 0 deletions
94
examples/images/reinforcementlearning/prepare/gym_cartpole.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import numpy as np | ||
import torch | ||
import torchvision.transforms as T | ||
from PIL import Image | ||
import gym | ||
|
||
class CartpoleWrapper(gym.Wrapper): | ||
def __init__(self, env, env_hyperparams): | ||
"""Initialize the wrapper for the CartPole environment to preprocess images. | ||
Args: | ||
env (gym.Env): The Gym environment to wrap. | ||
env_hyperparams (dict): Dictionary containing settings for image preprocessing, | ||
such as resize dimensions and whether to apply grayscale. | ||
""" | ||
super().__init__(env) | ||
self.env = env | ||
self.num_actions = env.action_space.n | ||
self.state_shape = env.observation_space.shape | ||
|
||
# Base transformations that are always applied | ||
transformations = [ | ||
T.ToPILImage(), | ||
T.Resize(env_hyperparams["resize_pixels"], | ||
interpolation=Image.BICUBIC), | ||
] | ||
|
||
# Conditional grayscale transformation | ||
if env_hyperparams["grayscale"]: | ||
transformations.append(T.Grayscale()) | ||
|
||
# Final transformation to tensor | ||
transformations.append(T.ToTensor()) | ||
|
||
# Compose all transformations into a single callable object | ||
self.resize = T.Compose(transformations) | ||
|
||
def get_cart_location(self, screen_width): | ||
"""Calculate the cart's location on the screen for cropping. | ||
Args: | ||
screen_width (int): The width of the screen from the environment. | ||
Returns: | ||
int: The pixel location of the center of the cart. | ||
""" | ||
world_width = self.env.x_threshold * 2 | ||
scale = screen_width / world_width | ||
return int(self.env.state[0] * scale + screen_width / 2.0) # Middle of the cart | ||
|
||
def get_screen(self): | ||
"""Capture, process, and crop the environment's screen. | ||
Transforms the screen into a format suitable for input to a neural network: | ||
crops, downsamples, converts to grayscale, and rescales. | ||
Returns: | ||
torch.Tensor: The processed screen tensor ready for model input. | ||
""" | ||
# Capture screen from the environment | ||
screen = self.env.render().transpose((2, 0, 1)) # CHW format | ||
_, screen_height, screen_width = screen.shape | ||
|
||
# Crop the vertical dimension to focus on the main area of interest | ||
screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)] | ||
|
||
# Define the width of the cropped area around the cart | ||
view_width = int(screen_width * 0.6) | ||
cart_location = self.get_cart_location(screen_width) | ||
|
||
# Calculate the horizontal slice range to center crop around the cart | ||
if cart_location < view_width // 2: | ||
slice_range = slice(view_width) | ||
elif cart_location > (screen_width - view_width // 2): | ||
slice_range = slice(-view_width, None) | ||
else: | ||
slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2) | ||
|
||
# Apply the calculated slice to crop horizontally | ||
screen = screen[:, :, slice_range] | ||
|
||
# Normalize, convert to tensor, resize, and add a batch dimension | ||
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255. | ||
screen = torch.from_numpy(screen) | ||
return self.resize(screen).unsqueeze(0) | ||
|
||
def step(self, action): | ||
"""Apply an action to the environment, returning the processed screen, reward, done, and info.""" | ||
return self.env.step(action) | ||
|
||
def reset(self): | ||
"""Reset the environment and return the initial processed screen.""" | ||
self.env.reset() | ||
|
Oops, something went wrong.