Commit

added RL
arnab39 committed Jul 31, 2024
1 parent ee039cf commit 13178a6
Showing 13 changed files with 512 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/images/classification/model.py
@@ -106,7 +106,7 @@ def training_step(self, batch: torch.Tensor):
if self.hyperparams.experiment.training.loss.automated_prior:
def metric_function(model_predictions, targets):
return -F.cross_entropy(model_predictions, targets, reduction='none')
-            prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.1)
+            prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.01)
prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore
else:
prior_loss = self.canonicalizer.get_prior_regularization_loss()
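The only functional change in this hunk is the prior temperature tau, lowered from 0.1 to 0.01. As a hedged illustration only (get_prior's internals are not shown in this diff), if tau acts as a softmax temperature over per-transformation metric scores, a smaller tau concentrates the prior on the best-scoring transformation:

import torch
import torch.nn.functional as F

# Hypothetical per-transformation scores (e.g. negative cross-entropy; higher is better).
scores = torch.tensor([[-0.2, -1.5, -0.9, -2.3]])

# Lowering tau from 0.1 to 0.01 makes the resulting distribution far more peaked.
for tau in (0.1, 0.01):
    print(tau, F.softmax(scores / tau, dim=-1))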
Empty file.
Empty file.
@@ -0,0 +1,11 @@
canonicalization_type: group_equivariant
network_type: e2cnn # Options for the canonicalization method: 1) e2cnn 2) custom 3) none
network_hyperparams:
  kernel_size: 5 # Kernel size for the canonicalization network
  out_channels: 32 # Number of output channels for the canonicalization network
  num_layers: 3 # Number of layers in the canonicalization network
group_type: rotation # Type of group for the canonicalization network
num_rotations: 4 # Number of rotations for the canonicalization network
beta: 1.0 # Beta parameter for the canonicalization network
input_crop_ratio: 0.8 # Ratio at which the input is cropped before canonicalization
resize_shape: 64 # Shape to which the input is resized before canonicalization
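This appears to be the group_equivariant entry of the canonicalization config group composed by the defaults list further below (the file path is not shown in this view). A minimal sketch of loading and inspecting it with OmegaConf, with the path purely illustrative:

from omegaconf import OmegaConf

# Illustrative path; this commit view does not show where the YAML above lives.
cfg = OmegaConf.load("configs/canonicalization/group_equivariant.yaml")

assert cfg.canonicalization_type == "group_equivariant"
print(cfg.network_hyperparams.kernel_size)   # 5
print(cfg.network_hyperparams.out_channels)  # 32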
@@ -0,0 +1 @@
canonicalization_type: identity
23 changes: 23 additions & 0 deletions examples/images/reinforcementlearning/configs/default.yaml
@@ -0,0 +1,23 @@
canonicalization_type: group_equivariant # will be set later in training script
device: cuda

# metadata specialised for each experiment
core:
  version: 0.0.1
  tags:
    - ${now:%Y-%m-%d}

hydra:
  run:
    dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/

  sweep:
    dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/
    subdir: ${hydra.job.num}_${hydra.job.id}

defaults:
  - _self_
  - env: default
  - experiment: default
  - canonicalization: identity
  - wandb: default
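A minimal sketch of the Hydra entry point this config implies, assuming a training script sitting next to the configs/ directory (the actual script is not included in this excerpt); note that hydra.run.dir resolves ${oc.env:HYDRA_JOBS}, so that environment variable must be set:

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="default")
def main(cfg: DictConfig) -> None:
    # The defaults list composes the env, experiment, canonicalization and wandb sub-configs.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()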
@@ -0,0 +1,12 @@
learning_rate: 0.001
batch_size: 128
gamma: 0.999
eps_start: 0.9
eps_end: 0.01
eps_decay: 3000
target_update: 10
replay_memory_size: 100000
end_score: 200
training_stop: 142
num_episodes: 50000
last_episodes_num: 20
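The eps_start / eps_end / eps_decay triple is the usual epsilon-greedy exploration schedule; a minimal sketch of how such values are commonly turned into an exploration rate (the training loop itself is not shown in this excerpt):

import math
import random

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.01, 3000  # values from the config above

def epsilon(steps_done: int) -> float:
    # Exponential decay from EPS_START toward EPS_END over roughly EPS_DECAY steps.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

def select_action(q_values, steps_done: int) -> int:
    # q_values: tensor of shape [1, num_actions], e.g. from the DQN defined below.
    if random.random() > epsilon(steps_done):
        return int(q_values.argmax(dim=1).item())
    return random.randrange(q_values.size(1))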
Empty file.
67 changes: 67 additions & 0 deletions examples/images/reinforcementlearning/network.py
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

class DQN(nn.Module):
def __init__(self, input_shape, num_actions, dueling_DQN=False):
super(DQN, self).__init__()

self.input_shape = input_shape
self.num_actions = num_actions
self.dueling_DQN = dueling_DQN

self.features = nn.Sequential(
nn.Conv2d(input_shape[0], 32, kernel_size=5, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=5, stride=2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=5, stride=2),
nn.BatchNorm2d(64),
nn.ReLU()
)

feature_size = self._get_feature_size()

if self.dueling_DQN:
self.advantage = nn.Sequential(
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, self.num_actions)
)
self.value = nn.Sequential(
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 1)
)
else:
self.action_value = nn.Sequential(
nn.Linear(feature_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, self.num_actions)
)

def forward(self, x):
x = x.float() / 255 # Normalize the input
x = self.features(x)
x = x.view(x.size(0), -1)

if self.dueling_DQN:
advantage = self.advantage(x)
value = self.value(x)
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
else:
q_values = self.action_value(x)

return q_values

    def _get_feature_size(self):
        # Temporarily switch to eval mode so the BatchNorm layers accept a single dummy sample.
        self.features.eval()
        with torch.no_grad():
            feature_size = self.features(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)
        self.features.train()  # restore the default training mode
        return feature_size
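A short usage sketch for the network above (shapes are illustrative); note that the BatchNorm layers require either a batch larger than one or eval() mode when computing a single state's Q-values:

import torch

# Illustrative shapes: a 1-channel 64x64 screen and CartPole's 2 actions.
net = DQN(input_shape=(1, 64, 64), num_actions=2, dueling_DQN=True)
net.eval()  # needed for single-state inference with BatchNorm

state = torch.zeros(1, 1, 64, 64)
with torch.no_grad():
    q_values = net(state)            # shape: [1, 2]
    action = q_values.argmax(dim=1)  # greedy action
print(q_values.shape, int(action.item()))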

Empty file.
94 changes: 94 additions & 0 deletions examples/images/reinforcementlearning/prepare/gym_cartpole.py
@@ -0,0 +1,94 @@
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import gym

class CartpoleWrapper(gym.Wrapper):
def __init__(self, env, env_hyperparams):
"""Initialize the wrapper for the CartPole environment to preprocess images.
Args:
env (gym.Env): The Gym environment to wrap.
env_hyperparams (dict): Dictionary containing settings for image preprocessing,
such as resize dimensions and whether to apply grayscale.
"""
super().__init__(env)
self.env = env
self.num_actions = env.action_space.n
self.state_shape = env.observation_space.shape

# Base transformations that are always applied
transformations = [
T.ToPILImage(),
T.Resize(env_hyperparams["resize_pixels"],
interpolation=Image.BICUBIC),
]

# Conditional grayscale transformation
if env_hyperparams["grayscale"]:
transformations.append(T.Grayscale())

# Final transformation to tensor
transformations.append(T.ToTensor())

# Compose all transformations into a single callable object
self.resize = T.Compose(transformations)

def get_cart_location(self, screen_width):
"""Calculate the cart's location on the screen for cropping.
Args:
screen_width (int): The width of the screen from the environment.
Returns:
int: The pixel location of the center of the cart.
"""
world_width = self.env.x_threshold * 2
scale = screen_width / world_width
return int(self.env.state[0] * scale + screen_width / 2.0) # Middle of the cart

def get_screen(self):
"""Capture, process, and crop the environment's screen.
Transforms the screen into a format suitable for input to a neural network:
crops, downsamples, converts to grayscale, and rescales.
Returns:
torch.Tensor: The processed screen tensor ready for model input.
"""
# Capture screen from the environment
screen = self.env.render().transpose((2, 0, 1)) # CHW format
_, screen_height, screen_width = screen.shape

# Crop the vertical dimension to focus on the main area of interest
screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]

# Define the width of the cropped area around the cart
view_width = int(screen_width * 0.6)
cart_location = self.get_cart_location(screen_width)

# Calculate the horizontal slice range to center crop around the cart
if cart_location < view_width // 2:
slice_range = slice(view_width)
elif cart_location > (screen_width - view_width // 2):
slice_range = slice(-view_width, None)
else:
slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)

# Apply the calculated slice to crop horizontally
screen = screen[:, :, slice_range]

# Normalize, convert to tensor, resize, and add a batch dimension
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255.
screen = torch.from_numpy(screen)
return self.resize(screen).unsqueeze(0)

    def step(self, action):
        """Apply an action to the wrapped environment and return its step output unchanged;
        the processed screen is obtained separately via get_screen()."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped environment; the initial processed screen can then be obtained via get_screen()."""
        self.env.reset()
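A short usage sketch for the wrapper above, assuming a recent Gym where render_mode="rgb_array" must be requested at construction time for env.render() to return pixels (the hyperparameter values are illustrative):

import gym

env_hyperparams = {"resize_pixels": 64, "grayscale": True}  # illustrative values
env = CartpoleWrapper(gym.make("CartPole-v1", render_mode="rgb_array"), env_hyperparams)

env.reset()
screen = env.get_screen()  # processed tensor of shape [1, 1, H, W] after grayscale and resize
step_output = env.step(env.action_space.sample())  # 5-tuple in gym >= 0.26, 4-tuple in older versions
env.close()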
