From 3a50954039843fd2b1171e8ccbf16834d69b599f Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <aravraj@fb.com>
Date: Sun, 13 Feb 2022 23:06:47 -0800
Subject: [PATCH 01/13] Using concurrent futures for parallelizing rollouts

---
 mjrl/samplers/core.py      | 29 ++++++++++++++++++++++++++---
 mjrl/utils/tensor_utils.py |  2 +-
 2 files changed, 27 insertions(+), 4 deletions(-)
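
A minimal, self-contained sketch of the pattern this patch switches to
(concurrent.futures.ProcessPoolExecutor instead of multiprocessing.Pool); the
rollout function below is a stand-in for mjrl's do_rollout, not the real thing:

    import concurrent.futures

    def rollout(seed, horizon=5):
        # stand-in for do_rollout: each worker returns a list of paths
        return [dict(seed=seed, length=horizon)]

    if __name__ == '__main__':
        input_dict_list = [dict(seed=s) for s in range(4)]
        with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(rollout, **kw) for kw in input_dict_list]
            results = [f.result() for f in futures]
        paths = [path for result in results for path in result]
        print(len(paths))  # -> 4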

diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py
index be4a988..8c267c5 100644
--- a/mjrl/samplers/core.py
+++ b/mjrl/samplers/core.py
@@ -6,6 +6,7 @@
 import multiprocessing as mp
 import time as timer
 logging.disable(logging.CRITICAL)
+import gc
 
 
 # Single core rollout to sample trajectories
@@ -93,6 +94,7 @@ def do_rollout(
         paths.append(path)
 
     del(env)
+    gc.collect()
     return paths
 
 
@@ -134,7 +136,7 @@ def sample_paths(
         start_time = timer.time()
         print("####### Gathering Samples #######")
 
-    results = _try_multiprocess(do_rollout, input_dict_list,
+    results = _try_multiprocess_cf(do_rollout, input_dict_list,
                                 num_cpu, max_process_time, max_timeouts)
     paths = []
     # result is a paths type and results is list of paths
@@ -186,7 +188,7 @@ def sample_data_batch(
     return paths
 
 
-def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
+def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
     
     # Base case
     if max_timeouts == 0:
@@ -202,9 +204,30 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_time
         pool.close()
         pool.terminate()
         pool.join()
-        return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
+        return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
 
     pool.close()
     pool.terminate()
     pool.join()  
     return results
+
+
+def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
+    import concurrent.futures
+    results = None
+    if max_timeouts != 0:
+        with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor:
+            submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list]
+            try:
+                results = [f.result(timeout=max_process_time) for f in submit_futures]
+            except concurrent.futures.TimeoutError as e:
+                print(str(e))
+                print("Timeout Error raised...") 
+            except concurrent.futures.CancelledError as e:
+                print(str(e))
+                print("Future Cancelled Error raised...") 
+            except Exception as e:
+                print(str(e))
+                print("Error raised...") 
+                raise e
+    return results
\ No newline at end of file
diff --git a/mjrl/utils/tensor_utils.py b/mjrl/utils/tensor_utils.py
index 8b0002a..fc7b90f 100644
--- a/mjrl/utils/tensor_utils.py
+++ b/mjrl/utils/tensor_utils.py
@@ -61,7 +61,7 @@ def high_res_normalize(probs):
 
 
 def stack_tensor_list(tensor_list):
-    return np.array(tensor_list)
+    return np.array(tensor_list, dtype=object)
     # tensor_shape = np.array(tensor_list[0]).shape
     # if tensor_shape is tuple():
     #     return np.array(tensor_list)
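
The tensor_utils change above exists because trajectories of unequal length
cannot be stacked into a regular ndarray; a quick illustration (behavior as of
NumPy >= 1.20, where implicit ragged arrays are deprecated and later removed):

    import numpy as np

    ragged = [np.zeros(3), np.zeros(5)]    # e.g. observations from two rollouts
    arr = np.array(ragged, dtype=object)   # fine: 1-D object array of arrays
    print(arr.shape, arr.dtype)            # -> (2,) object
    # np.array(ragged) without dtype=object warns on NumPy 1.20-1.23
    # and raises ValueError on NumPy >= 1.24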

From 74e3df32217f223bec3b1a74a907c9c1f3f0cf93 Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <aravraj@fb.com>
Date: Sun, 13 Feb 2022 23:07:31 -0800
Subject: [PATCH 02/13] Adding BatchNormMLP policy class

---
 mjrl/algos/behavior_cloning.py |  2 ++
 mjrl/policies/gaussian_mlp.py  | 61 +++++++++++++++++++++++++++++++++-
 mjrl/utils/fc_network.py       | 60 +++++++++++++++++++++++++++++----
 3 files changed, 116 insertions(+), 7 deletions(-)
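
The behavior_cloning.py change matters because BatchNorm layers behave
differently in train and eval mode; a small sketch of the distinction the
patch toggles around the BC training loop:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(num_features=3)
    x = torch.randn(8, 3)

    bn.train()
    y_train = bn(x)  # normalizes with batch statistics, updates running stats
    bn.eval()
    y_eval = bn(x)   # normalizes with the accumulated running statistics

    print(torch.allclose(y_train, y_eval))  # -> False in general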

diff --git a/mjrl/algos/behavior_cloning.py b/mjrl/algos/behavior_cloning.py
index 6aac09e..bc1d512 100644
--- a/mjrl/algos/behavior_cloning.py
+++ b/mjrl/algos/behavior_cloning.py
@@ -118,6 +118,7 @@ def fit(self, data, suppress_fit_tqdm=False, **kwargs):
             self.logger.log_kv('loss_before', loss_val)
 
         # train loop
+        self.policy.model.train()
         for ep in config_tqdm(range(self.epochs), suppress_fit_tqdm):
             for mb in range(int(num_samples / self.mb_size)):
                 rand_idx = np.random.choice(num_samples, size=self.mb_size)
@@ -125,6 +126,7 @@ def fit(self, data, suppress_fit_tqdm=False, **kwargs):
                 loss = self.loss(data, idx=rand_idx)
                 loss.backward()
                 self.optimizer.step()
+        self.policy.model.eval()
         params_after_opt = self.policy.get_param_values()
         self.policy.set_param_values(params_after_opt, set_new=True, set_old=True)
 
diff --git a/mjrl/policies/gaussian_mlp.py b/mjrl/policies/gaussian_mlp.py
index ae97145..27293fe 100644
--- a/mjrl/policies/gaussian_mlp.py
+++ b/mjrl/policies/gaussian_mlp.py
@@ -1,5 +1,5 @@
 import numpy as np
-from mjrl.utils.fc_network import FCNetwork
+from mjrl.utils.fc_network import FCNetwork, FCNetworkWithBatchNorm
 import torch
 from torch.autograd import Variable
 
@@ -143,3 +143,62 @@ def mean_kl(self, new_dist_info, old_dist_info):
         Dr = 2 * new_std ** 2 + 1e-8
         sample_kl = torch.sum(Nr / Dr + new_log_std - old_log_std, dim=1)
         return torch.mean(sample_kl)
+
+
+class BatchNormMLP(MLP):
+    def __init__(self, env_spec,
+                 hidden_sizes=(64,64),
+                 min_log_std=-3,
+                 init_log_std=0,
+                 seed=None,
+                 nonlinearity='relu',
+                 *args, **kwargs,
+                 ):
+        """
+        :param env_spec: specifications of the env (see utils/gym_env.py)
+        :param hidden_sizes: network hidden layer sizes (currently 2 layers only)
+        :param min_log_std: log_std is clamped at this value and can't go below
+        :param init_log_std: initial log standard deviation
+        :param seed: random seed
+        """
+        # super(BatchNormMLP, self).__init__()
+
+        self.n = env_spec.observation_dim  # number of states
+        self.m = env_spec.action_dim  # number of actions
+        self.min_log_std = min_log_std
+
+        # Set seed
+        # ------------------------
+        if seed is not None:
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
+        # Policy network
+        # ------------------------
+        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity)
+        # make weights small
+        for param in list(self.model.parameters())[-2:]:  # only last layer
+           param.data = 1e-2 * param.data
+        self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
+        self.trainable_params = list(self.model.parameters()) + [self.log_std]
+        self.model.eval()
+
+        # Old Policy network
+        # ------------------------
+        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity)
+        self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
+        self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
+        for idx, param in enumerate(self.old_params):
+            param.data = self.trainable_params[idx].data.clone()
+        self.old_model.eval()
+
+        # Easy access variables
+        # -------------------------
+        self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
+        self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
+        self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
+        self.d = np.sum(self.param_sizes)  # total number of params
+
+        # Placeholders
+        # ------------------------
+        self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
\ No newline at end of file
diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index ea3ad72..1bb0382 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -37,12 +37,14 @@ def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_
         self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim)
 
     def forward(self, x):
-        # TODO(Aravind): Remove clamping to CPU
-        # This is a temp change that should be fixed shortly
-        if x.is_cuda:
-            out = x.to('cpu')
-        else:
-            out = x
+        try:
+            out = x.to(self.device)
+        except AttributeError:
+            if not hasattr(self, 'device'):
+                self.device = 'cpu'
+                out = x.to(self.device)
+            else:
+                raise
         out = (out - self.in_shift)/(self.in_scale + 1e-8)
         for i in range(len(self.fc_layers)-1):
             out = self.fc_layers[i](out)
@@ -50,3 +52,49 @@ def forward(self, x):
         out = self.fc_layers[-1](out)
         out = out * self.out_scale + self.out_shift
         return out
+
+    def to(self, device):
+        self.device = device
+        # change the transforms to the appropriate device
+        self.in_shift = self.in_shift.to(device)
+        self.in_scale = self.in_scale.to(device)
+        self.out_shift = self.out_shift.to(device)
+        self.out_scale = self.out_scale.to(device)
+        # move all other trainable parameters to device
+        super().to(device)
+
+
+class FCNetworkWithBatchNorm(nn.Module):
+    def __init__(self, obs_dim, act_dim,
+                 hidden_sizes=(64,64),
+                 nonlinearity='relu',   # either 'tanh' or 'relu'
+                ):
+        super(FCNetworkWithBatchNorm, self).__init__()
+
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        assert type(hidden_sizes) == tuple
+        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
+        self.device = 'cpu'
+
+        # hidden layers
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
+                         for i in range(len(self.layer_sizes) -1)])
+        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
+        self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
+
+    def forward(self, x):
+        out = x.to(self.device)
+        out = self.input_batchnorm(out)
+        for i in range(len(self.fc_layers)-1):
+            out = self.fc_layers[i](out)
+            out = self.nonlinearity(out)
+        out = self.fc_layers[-1](out)
+        return out
+
+    def to(self, device):
+        self.device = device
+        super().to(device)
+
+    def set_transformations(self, *args, **kwargs):
+        pass

From 5d3c4c5e02abdc633859313aace4b51e20556892 Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <aravraj@fb.com>
Date: Fri, 18 Feb 2022 12:51:10 -0800
Subject: [PATCH 03/13] Adding dropout option to BatchNorm network

---
 mjrl/policies/gaussian_mlp.py | 5 +++--
 mjrl/utils/fc_network.py      | 4 ++++
 2 files changed, 7 insertions(+), 2 deletions(-)
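
Like batch norm, nn.Dropout is only active in train mode, which is why
BatchNormMLP keeps its models in eval mode outside of fitting; a minimal
sketch:

    import torch
    import torch.nn as nn

    drop = nn.Dropout(p=0.5)
    x = torch.ones(6)

    drop.train()
    print(drop(x))  # about half the entries zeroed, survivors scaled by 1/(1-p)
    drop.eval()
    print(drop(x))  # identity: all ones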

diff --git a/mjrl/policies/gaussian_mlp.py b/mjrl/policies/gaussian_mlp.py
index 27293fe..a165bb7 100644
--- a/mjrl/policies/gaussian_mlp.py
+++ b/mjrl/policies/gaussian_mlp.py
@@ -152,6 +152,7 @@ def __init__(self, env_spec,
                  init_log_std=0,
                  seed=None,
                  nonlinearity='relu',
+                 dropout=0,
                  *args, **kwargs,
                  ):
         """
@@ -175,7 +176,7 @@ def __init__(self, env_spec,
 
         # Policy network
         # ------------------------
-        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity)
+        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout)
         # make weights small
         for param in list(self.model.parameters())[-2:]:  # only last layer
            param.data = 1e-2 * param.data
@@ -185,7 +186,7 @@ def __init__(self, env_spec,
 
         # Old Policy network
         # ------------------------
-        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity)
+        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout)
         self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
         self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
         for idx, param in enumerate(self.old_params):
diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index 1bb0382..c93231c 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -68,6 +68,8 @@ class FCNetworkWithBatchNorm(nn.Module):
     def __init__(self, obs_dim, act_dim,
                  hidden_sizes=(64,64),
                  nonlinearity='relu',   # either 'tanh' or 'relu'
+                 dropout=0,           # probability to dropout activations (0 means no dropout)
+                 *args, **kwargs,
                 ):
         super(FCNetworkWithBatchNorm, self).__init__()
 
@@ -82,12 +84,14 @@ def __init__(self, obs_dim, act_dim,
                          for i in range(len(self.layer_sizes) -1)])
         self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
         self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
         out = x.to(self.device)
         out = self.input_batchnorm(out)
         for i in range(len(self.fc_layers)-1):
             out = self.fc_layers[i](out)
+            out = self.dropout(out)
             out = self.nonlinearity(out)
         out = self.fc_layers[-1](out)
         return out

From 35c0e78979d2e8b3237481772698ac84c4c19a0e Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <aravraj@fb.com>
Date: Fri, 18 Feb 2022 12:52:47 -0800
Subject: [PATCH 04/13] Targeting unwrapped env for low-level functionality
 like states and rendering

---
 mjrl/utils/gym_env.py | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)
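
env.unwrapped reaches the innermost environment through any number of wrapper
layers, whereas env.env only peels off one; a sketch of the difference,
assuming a standard gym install:

    import gym

    env = gym.make('CartPole-v1')        # gym.make returns a wrapped env
    print(type(env).__name__)            # -> a wrapper class (e.g. TimeLimit)
    print(type(env.unwrapped).__name__)  # -> CartPoleEnv, however deep the wrapping
    print(env.unwrapped.spec.id)         # -> CartPole-v1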

diff --git a/mjrl/utils/gym_env.py b/mjrl/utils/gym_env.py
index eb45902..434b7e2 100644
--- a/mjrl/utils/gym_env.py
+++ b/mjrl/utils/gym_env.py
@@ -15,9 +15,9 @@ def __init__(self, obs_dim, act_dim, horizon):
 
 class GymEnv(object):
     def __init__(self, env, env_kwargs=None,
-                 obs_mask=None, act_repeat=1, 
+                 obs_mask=None, act_repeat=1,
                  *args, **kwargs):
-    
+
         # get the correct env behavior
         if type(env) == str:
             env = gym.make(env)
@@ -30,7 +30,7 @@ def __init__(self, env, env_kwargs=None,
             raise AttributeError
 
         self.env = env
-        self.env_id = env.spec.id
+        self.env_id = env.unwrapped.spec.id
         self.act_repeat = act_repeat
 
         try:
@@ -42,14 +42,14 @@ def __init__(self, env, env_kwargs=None,
         self._horizon = self._horizon // self.act_repeat
 
         try:
-            self._action_dim = self.env.env.action_dim
-        except AttributeError:
             self._action_dim = self.env.action_space.shape[0]
+        except AttributeError:
+            self._action_dim = self.env.unwrapped.action_dim
 
         try:
-            self._observation_dim = self.env.env.obs_dim
-        except AttributeError:
             self._observation_dim = self.env.observation_space.shape[0]
+        except AttributeError:
+            self._observation_dim = self.env.unwrapped.obs_dim
 
         # Specs
         self.spec = EnvSpec(self._observation_dim, self._action_dim, self._horizon)
@@ -80,7 +80,7 @@ def horizon(self):
     def reset(self, seed=None):
         try:
             self.env._elapsed_steps = 0
-            return self.env.env.reset_model(seed=seed)
+            return self.env.unwrapped.reset_model(seed=seed)
         except:
             if seed is not None:
                 self.set_seed(seed)
@@ -92,7 +92,7 @@ def reset_model(self, seed=None):
 
     def step(self, action):
         action = action.clip(self.action_space.low, self.action_space.high)
-        if self.act_repeat == 1: 
+        if self.act_repeat == 1:
             obs, cum_reward, done, ifo = self.env.step(action)
         else:
             cum_reward = 0.0
@@ -104,8 +104,8 @@ def step(self, action):
 
     def render(self):
         try:
-            self.env.env.mujoco_render_frames = True
-            self.env.env.mj_render()
+            self.env.unwrapped.mujoco_render_frames = True
+            self.env.unwrapped.mj_render()
         except:
             self.env.render()
 
@@ -117,13 +117,13 @@ def set_seed(self, seed=123):
 
     def get_obs(self):
         try:
-            return self.obs_mask * self.env.env.get_obs()
+            return self.obs_mask * self.env.get_obs()
         except:
-            return self.obs_mask * self.env.env._get_obs()
+            return self.obs_mask * self.env._get_obs()
 
     def get_env_infos(self):
         try:
-            return self.env.env.get_env_infos()
+            return self.env.unwrapped.get_env_infos()
         except:
             return {}
 
@@ -133,19 +133,19 @@ def get_env_infos(self):
 
     def get_env_state(self):
         try:
-            return self.env.env.get_env_state()
+            return self.env.unwrapped.get_env_state()
         except:
             raise NotImplementedError
 
     def set_env_state(self, state_dict):
         try:
-            self.env.env.set_env_state(state_dict)
+            self.env.unwrapped.set_env_state(state_dict)
         except:
             raise NotImplementedError
 
     def real_env_step(self, bool_val):
         try:
-            self.env.env.real_step = bool_val
+            self.env.unwrapped.real_step = bool_val
         except:
             raise NotImplementedError
 
@@ -153,7 +153,7 @@ def real_env_step(self, bool_val):
 
     def visualize_policy(self, policy, horizon=1000, num_episodes=1, mode='exploration'):
         try:
-            self.env.env.visualize_policy(policy, horizon, num_episodes, mode)
+            self.env.unwrapped.visualize_policy(policy, horizon, num_episodes, mode)
         except:
             for ep in range(num_episodes):
                 o = self.reset()

From 978b4bbb15d467f8baa4e4783f28eea397b47b43 Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <aravraj@fb.com>
Date: Fri, 18 Feb 2022 13:10:31 -0800
Subject: [PATCH 05/13] Updating pytorch version

---
 setup/env.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup/env.yml b/setup/env.yml
index 5bc4397..04e4159 100644
--- a/setup/env.yml
+++ b/setup/env.yml
@@ -7,7 +7,7 @@ dependencies:
 - pip
 - ipython
 - mkl-service
-- pytorch==1.4
+- pytorch==1.9.0
 - tabulate
 - termcolor
 - torchvision

From 83d35df95eb64274c5e93bb32a0a4e2f6576638a Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran <rajeswaran.aravind@gmail.com>
Date: Tue, 31 May 2022 00:32:02 -0700
Subject: [PATCH 06/13] Adding wandb integration

---
 mjrl/utils/logger.py      | 22 ++++++++++++++++++++--
 mjrl/utils/train_agent.py | 10 +++++++---
 2 files changed, 27 insertions(+), 5 deletions(-)
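
The wandb calls used here follow the standard init/log pattern; a
self-contained sketch (mode='offline' avoids needing an account, and the
project/entity names are just the defaults this patch hardcodes):

    import wandb

    run = wandb.init(project='mjrl_test', entity='aravraj',
                     config={'lr': 1e-3}, mode='offline')
    for step in range(3):
        run.log({'loss': 1.0 / (step + 1)}, step=step)
    run.finish()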

diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py
index 0a155ed..6494eb2 100644
--- a/mjrl/utils/logger.py
+++ b/mjrl/utils/logger.py
@@ -7,15 +7,26 @@
 import os
 import csv
 
+# Defaults
+USERNAME = 'aravraj'
+WANDB_PROJECT = 'mjrl_test'
+
 class DataLog:
 
-    def __init__(self):
+    def __init__(self, use_wandb:bool = True,
+                 wandb_user:str = USERNAME,
+                 wandb_project:str = WANDB_PROJECT,
+                 wandb_config:dict = dict()) -> None:
+        self.use_wandb = use_wandb
+        if use_wandb:
+            import wandb
+            self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config)
         self.log = {}
         self.max_len = 0
+        self.global_step = 0
 
     def log_kv(self, key, value):
         # logs the (key, value) pair
-
         # TODO: This implementation is error-prone:
         # it would be NOT aligned if some keys are missing during one iteration.
         if key not in self.log:
@@ -23,6 +34,8 @@ def log_kv(self, key, value):
         self.log[key].append(value)
         if len(self.log[key]) > self.max_len:
             self.max_len = self.max_len + 1
+        if self.use_wandb:
+            self.run.log({key: value}, step=self.global_step)
 
     def save_log(self, save_path):
         # TODO: Validate all lengths are the same.
@@ -56,6 +69,11 @@ def shrink_to(self, num_entries):
         assert min([len(series) for series in self.log.values()]) == \
             max([len(series) for series in self.log.values()])
 
+    def sync_log_with_wandb(self):
+        # Syncs the latest logged entries with wandb
+        latest_log = self.get_current_log()
+        self.run.log(latest_log, step=self.global_step)
+
     def read_log(self, log_path):
         assert log_path.endswith('log.csv')
 
diff --git a/mjrl/utils/train_agent.py b/mjrl/utils/train_agent.py
index 688b638..651a894 100644
--- a/mjrl/utils/train_agent.py
+++ b/mjrl/utils/train_agent.py
@@ -1,6 +1,3 @@
-import logging
-logging.disable(logging.CRITICAL)
-
 from tabulate import tabulate
 from mjrl.utils.make_train_plots import make_train_plots
 from mjrl.utils.gym_env import GymEnv
@@ -94,9 +91,12 @@ def train_agent(job_name, agent,
     if i_start:
         print("Resuming from an existing job folder ...")
 
+    env_samples = 0
     for i in range(i_start, niter):
         print("......................................................................................")
         print("ITERATION : %i " % i)
+        if agent.logger.use_wandb:
+            agent.logger.global_step += 1
 
         if train_curve[i-1] > best_perf:
             best_policy = copy.deepcopy(agent.policy)
@@ -106,6 +106,10 @@ def train_agent(job_name, agent,
         args = dict(N=N, sample_mode=sample_mode, gamma=gamma, gae_lambda=gae_lambda, num_cpu=num_cpu)
         stats = agent.train_step(**args)
         train_curve[i] = stats[0]
+        # log total number of samples so far for convenience
+        iter_samples = agent.logger.get_current_log()['num_samples']
+        env_samples += iter_samples
+        agent.logger.log_kv('env_samples', env_samples)
 
         if evaluation_rollouts is not None and evaluation_rollouts > 0:
             print("Performing evaluation rollouts ........")

From b42c4bdede939d1936c124a2eaee0e6778024e6b Mon Sep 17 00:00:00 2001
From: Vikash Kumar <vikashplus@gmail.com>
Date: Wed, 17 May 2023 18:34:02 -0400
Subject: [PATCH 07/13] Users can now specify and pass wandb details from the
 config files

---
 examples/example_configs/hopper_npg.txt  | 6 ++++++
 examples/example_configs/swimmer_npg.txt | 6 ++++++
 examples/example_configs/swimmer_ppo.txt | 8 +++++++-
 examples/policy_opt_job_script.py        | 7 +++++++
 mjrl/utils/logger.py                     | 7 ++++++-
 5 files changed, 32 insertions(+), 2 deletions(-)
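
How the new config block reaches the logger, sketched under the assumption
(matching the example configs above, which are Python dict literals in .txt
files) that the job script reads them with eval; the file path is hypothetical:

    from mjrl.utils.logger import DataLog

    with open('examples/example_configs/swimmer_ppo.txt') as f:
        job_data = eval(f.read())

    if 'wandb_params' in job_data and job_data['wandb_params']['use_wandb']:
        logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)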

diff --git a/examples/example_configs/hopper_npg.txt b/examples/example_configs/hopper_npg.txt
index bd98381..bf16679 100644
--- a/examples/example_configs/hopper_npg.txt
+++ b/examples/example_configs/hopper_npg.txt
@@ -29,5 +29,11 @@
 
 'alg_hyper_params'  :   dict(),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
 
diff --git a/examples/example_configs/swimmer_npg.txt b/examples/example_configs/swimmer_npg.txt
index f8af3a8..8c09008 100644
--- a/examples/example_configs/swimmer_npg.txt
+++ b/examples/example_configs/swimmer_npg.txt
@@ -29,4 +29,10 @@
 
 'alg_hyper_params'  :   dict(),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
\ No newline at end of file
diff --git a/examples/example_configs/swimmer_ppo.txt b/examples/example_configs/swimmer_ppo.txt
index e20d592..6c561e0 100644
--- a/examples/example_configs/swimmer_ppo.txt
+++ b/examples/example_configs/swimmer_ppo.txt
@@ -7,7 +7,7 @@
 'seed'              :   123,
 'sample_mode'       :   'trajectories',
 'rl_num_traj'       :   10,
-'rl_num_iter'       :   50,
+'rl_num_iter'       :   10,
 'num_cpu'           :   2,
 'save_freq'         :   25,
 'eval_rollouts'     :   None,
@@ -29,4 +29,10 @@
 
 'alg_hyper_params'  :   dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
\ No newline at end of file
diff --git a/examples/policy_opt_job_script.py b/examples/policy_opt_job_script.py
index 0ee68df..7253ed0 100644
--- a/examples/policy_opt_job_script.py
+++ b/examples/policy_opt_job_script.py
@@ -13,6 +13,7 @@
 from mjrl.algos.batch_reinforce import BatchREINFORCE
 from mjrl.algos.ppo_clip import PPO
 from mjrl.utils.train_agent import train_agent
+from mjrl.utils.logger import DataLog
 import os
 import json
 import gym
@@ -82,6 +83,12 @@
     # or defaults in the PPO algorithm will be used
     agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params'])
 
+
+# Update logger if WandB in Config
+if 'wandb_params' in job_data and job_data['wandb_params']['use_wandb']:
+    agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)
+
+
 print("========================================")
 print("Starting policy learning")
 print("========================================")
diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py
index 6494eb2..9fd4845 100644
--- a/mjrl/utils/logger.py
+++ b/mjrl/utils/logger.py
@@ -13,14 +13,19 @@
 
 class DataLog:
 
-    def __init__(self, use_wandb:bool = True,
+    def __init__(self,
+                 use_wandb:bool = False,
                  wandb_user:str = USERNAME,
                  wandb_project:str = WANDB_PROJECT,
+                 wandb_exp:str = None,
                  wandb_config:dict = dict()) -> None:
         self.use_wandb = use_wandb
         if use_wandb:
             import wandb
             self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config)
+            # Update exp name if explicitly specified
+            if wandb_exp is not None: wandb.run.name = wandb_exp
+
         self.log = {}
         self.max_len = 0
         self.global_step = 0

From cc6af69aaef97d3db6e5f748bf7e4e3f6e4bb723 Mon Sep 17 00:00:00 2001
From: Vikash Kumar <vikashplus@gmail.com>
Date: Sun, 21 May 2023 00:17:54 -0400
Subject: [PATCH 08/13] FEATURE: directory can be specified for wandb logs

---
 examples/policy_opt_job_script.py | 4 ++++
 mjrl/utils/logger.py              | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)
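
wandb.init accepts a dir argument that controls where the local wandb/ run
directory is created; a sketch of what the job script now wires up (the paths
here are hypothetical):

    import os
    import wandb

    job_dir = '/tmp/my_job'  # stands in for JOB_DIR
    logdir = os.path.join(job_dir, 'wandb_logs')
    os.makedirs(logdir, exist_ok=True)
    run = wandb.init(project='mjrl_demo', dir=logdir, mode='offline')
    run.finish()             # events end up under logdir/wandb/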

diff --git a/examples/policy_opt_job_script.py b/examples/policy_opt_job_script.py
index 7253ed0..e0ae249 100644
--- a/examples/policy_opt_job_script.py
+++ b/examples/policy_opt_job_script.py
@@ -86,6 +86,10 @@
 
 # Update logger if WandB in Config
 if 'wandb_params' in job_data and job_data['wandb_params']['use_wandb']:
+    if 'wandb_logdir' in job_data['wandb_params']:
+        job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir'])
+    else:
+        job_data['wandb_params']['wandb_logdir'] = JOB_DIR
     agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)
 
 
diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py
index 9fd4845..f96a074 100644
--- a/mjrl/utils/logger.py
+++ b/mjrl/utils/logger.py
@@ -18,11 +18,12 @@ def __init__(self,
                  wandb_user:str = USERNAME,
                  wandb_project:str = WANDB_PROJECT,
                  wandb_exp:str = None,
+                 wandb_logdir:str = None,
                  wandb_config:dict = dict()) -> None:
         self.use_wandb = use_wandb
         if use_wandb:
             import wandb
-            self.run = wandb.init(project=wandb_project, entity=wandb_user, config=wandb_config)
+            self.run = wandb.init(project=wandb_project, entity=wandb_user, dir=wandb_logdir, config=wandb_config)
             # Update exp name if explicitly specified
             if wandb_exp is not None: wandb.run.name = wandb_exp
 

From 104732fd6f46e754e7d652d72ab4b60178d9e821 Mon Sep 17 00:00:00 2001
From: Vikash Kumar <vikashplus@gmail.com>
Date: Sat, 24 Jun 2023 14:28:57 -0400
Subject: [PATCH 09/13] Users should explicitly import these envs if need be.
 They have a mujoco_py dependency that not all setups have

---
 mjrl/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mjrl/__init__.py b/mjrl/__init__.py
index 008b133..affb942 100644
--- a/mjrl/__init__.py
+++ b/mjrl/__init__.py
@@ -1 +1,2 @@
-import mjrl.envs
\ No newline at end of file
+# Users should explicitly import these envs if need be. They have mujoco_py dependency that all setups have
+# import mjrl.envs
\ No newline at end of file

From 6e058535b8f0fa3368ce33e7090abbf275f9307e Mon Sep 17 00:00:00 2001
From: Vikash Kumar <vikashplus@gmail.com>
Date: Sat, 24 Jun 2023 14:35:20 -0400
Subject: [PATCH 10/13] Users should explicitly import these envs if need be.
 They have a mujoco_py dependency that not all setups have

---
 mjrl/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mjrl/__init__.py b/mjrl/__init__.py
index affb942..00e188e 100644
--- a/mjrl/__init__.py
+++ b/mjrl/__init__.py
@@ -1,2 +1,2 @@
-# Users should explicitly import these envs if need be. They have mujoco_py dependency that all setups have
+# Users should explicitly import these envs if need be. They have mujoco_py dependency that not all setups have
 # import mjrl.envs
\ No newline at end of file

From 593e870388673ba33b1c3c2a03d8e5e382b4b1c3 Mon Sep 17 00:00:00 2001
From: ElyasYassin <164119976+ElyasYassin@users.noreply.github.com>
Date: Sat, 18 May 2024 23:52:22 -0600
Subject: [PATCH 11/13] Added Neural Recordings

Records neuron activation after each forward pass
---
 mjrl/utils/fc_network.py | 104 +++++++++++++--------------------------
 1 file changed, 33 insertions(+), 71 deletions(-)
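
The recording mechanism is PyTorch's register_forward_hook; a minimal
standalone sketch of capturing a layer's output on every forward pass:

    import torch
    import torch.nn as nn

    activations = []
    layer = nn.Linear(4, 2)
    handle = layer.register_forward_hook(
        lambda module, inputs, output: activations.append(output.detach()))

    layer(torch.randn(1, 4))
    layer(torch.randn(1, 4))
    print(len(activations), activations[0].shape)  # -> 2 torch.Size([1, 2])
    handle.remove()  # detach the hook when done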

diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index c93231c..19f59cc 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -2,15 +2,14 @@
 import torch
 import torch.nn as nn
 
-
 class FCNetwork(nn.Module):
     def __init__(self, obs_dim, act_dim,
                  hidden_sizes=(64,64),
                  nonlinearity='tanh',   # either 'tanh' or 'relu'
-                 in_shift = None,
-                 in_scale = None,
-                 out_shift = None,
-                 out_scale = None):
+                 in_shift=None,
+                 in_scale=None,
+                 out_shift=None,
+                 out_scale=None):
         super(FCNetwork, self).__init__()
 
         self.obs_dim = obs_dim
@@ -20,31 +19,44 @@ def __init__(self, obs_dim, act_dim,
         self.set_transformations(in_shift, in_scale, out_shift, out_scale)
 
         # hidden layers
-        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
-                         for i in range(len(self.layer_sizes) -1)])
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1])
+                                         for i in range(len(self.layer_sizes) -1)])
         self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
 
+        # Register hooks for each layer
+        self.activation_hooks = []
+        for layer in self.fc_layers:
+            hook = layer.register_forward_hook(self.record_activation)
+            self.activation_hooks.append(hook)
+
+        self.activations = []
+
+    def record_activation(self, module, input, output):
+        # Record the activations during forward pass
+        self.activations.append(output)
+
     def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_scale=None):
         # store native scales that can be used for resets
         self.transformations = dict(in_shift=in_shift,
-                           in_scale=in_scale,
-                           out_shift=out_shift,
-                           out_scale=out_scale
-                          )
-        self.in_shift  = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim)
-        self.in_scale  = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim)
+                                    in_scale=in_scale,
+                                    out_shift=out_shift,
+                                    out_scale=out_scale
+                                    )
+        self.in_shift = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim)
+        self.in_scale = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim)
         self.out_shift = torch.from_numpy(np.float32(out_shift)) if out_shift is not None else torch.zeros(self.act_dim)
         self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim)
 
     def forward(self, x):
-        try:
-            out = x.to(self.device)
-        except AttributeError:
-            if not hasattr(self, 'device'):
-                self.device = 'cpu'
-                out = x.to(self.device)
-            else:
-                raise
+        # Reset activations
+        self.activations = []
+
+        # TODO(Aravind): Remove clamping to CPU
+        # This is a temp change that should be fixed shortly
+        if x.is_cuda:
+            out = x.to('cpu')
+        else:
+            out = x
         out = (out - self.in_shift)/(self.in_scale + 1e-8)
         for i in range(len(self.fc_layers)-1):
             out = self.fc_layers[i](out)
@@ -52,53 +64,3 @@ def forward(self, x):
         out = self.fc_layers[-1](out)
         out = out * self.out_scale + self.out_shift
         return out
-
-    def to(self, device):
-        self.device = device
-        # change the transforms to the appropriate device
-        self.in_shift = self.in_shift.to(device)
-        self.in_scale = self.in_scale.to(device)
-        self.out_shift = self.out_shift.to(device)
-        self.out_scale = self.out_scale.to(device)
-        # move all other trainable parameters to device
-        super().to(device)
-
-
-class FCNetworkWithBatchNorm(nn.Module):
-    def __init__(self, obs_dim, act_dim,
-                 hidden_sizes=(64,64),
-                 nonlinearity='relu',   # either 'tanh' or 'relu'
-                 dropout=0,           # probability to dropout activations (0 means no dropout)
-                 *args, **kwargs,
-                ):
-        super(FCNetworkWithBatchNorm, self).__init__()
-
-        self.obs_dim = obs_dim
-        self.act_dim = act_dim
-        assert type(hidden_sizes) == tuple
-        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
-        self.device = 'cpu'
-
-        # hidden layers
-        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
-                         for i in range(len(self.layer_sizes) -1)])
-        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
-        self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        out = x.to(self.device)
-        out = self.input_batchnorm(out)
-        for i in range(len(self.fc_layers)-1):
-            out = self.fc_layers[i](out)
-            out = self.dropout(out)
-            out = self.nonlinearity(out)
-        out = self.fc_layers[-1](out)
-        return out
-
-    def to(self, device):
-        self.device = device
-        super().to(device)
-
-    def set_transformations(self, *args, **kwargs):
-        pass

From bf6f7a14b5270831700e26cdfe9ee3d2f5f76e76 Mon Sep 17 00:00:00 2001
From: ElyasYassin <164119976+ElyasYassin@users.noreply.github.com>
Date: Sun, 26 May 2024 21:33:04 -0600
Subject: [PATCH 12/13] fixed error in network #1

---
 mjrl/utils/fc_network.py | 119 ++++++++++++++++++++++++++++++---------
 1 file changed, 92 insertions(+), 27 deletions(-)
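
The histograms go to TensorBoard via torch.utils.tensorboard.SummaryWriter; a
sketch of the logging calls this patch relies on (the log directory is
arbitrary):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/activations_demo')
    for step in range(3):
        writer.add_histogram('Activations/fc_layer_0', torch.randn(64),
                             global_step=step)
    writer.close()  # flush events; inspect with `tensorboard --logdir runs`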

diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index 19f59cc..6b5837b 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -1,6 +1,7 @@
-import numpy as np
 import torch
 import torch.nn as nn
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 
 class FCNetwork(nn.Module):
     def __init__(self, obs_dim, act_dim,
@@ -9,58 +10,122 @@ def __init__(self, obs_dim, act_dim,
                  in_shift=None,
                  in_scale=None,
                  out_shift=None,
-                 out_scale=None):
+                 out_scale=None,
+                 log_dir='runs/activations'):
         super(FCNetwork, self).__init__()
 
         self.obs_dim = obs_dim
         self.act_dim = act_dim
         assert type(hidden_sizes) == tuple
-        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
+        self.layer_sizes = (obs_dim,) + hidden_sizes + (act_dim,)
         self.set_transformations(in_shift, in_scale, out_shift, out_scale)
 
-        # hidden layers
         self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1])
-                                         for i in range(len(self.layer_sizes) -1)])
+                                        for i in range(len(self.layer_sizes) - 1)])
         self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
-
-        # Register hooks for each layer
-        self.activation_hooks = []
-        for layer in self.fc_layers:
-            hook = layer.register_forward_hook(self.record_activation)
-            self.activation_hooks.append(hook)
-
-        self.activations = []
-
-    def record_activation(self, module, input, output):
-        # Record the activations during forward pass
-        self.activations.append(output)
+        
+        self.activations = {}
+        self.writer = SummaryWriter(log_dir)
 
     def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_scale=None):
-        # store native scales that can be used for resets
         self.transformations = dict(in_shift=in_shift,
                                     in_scale=in_scale,
                                     out_shift=out_shift,
-                                    out_scale=out_scale
-                                    )
+                                    out_scale=out_scale)
         self.in_shift = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim)
         self.in_scale = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim)
         self.out_shift = torch.from_numpy(np.float32(out_shift)) if out_shift is not None else torch.zeros(self.act_dim)
         self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim)
 
     def forward(self, x):
-        # Reset activations
-        self.activations = []
-
-        # TODO(Aravind): Remove clamping to CPU
-        # This is a temp change that should be fixed shortly
         if x.is_cuda:
             out = x.to('cpu')
         else:
             out = x
-        out = (out - self.in_shift)/(self.in_scale + 1e-8)
-        for i in range(len(self.fc_layers)-1):
+        out = (out - self.in_shift) / (self.in_scale + 1e-8)
+        for i in range(len(self.fc_layers) - 1):
             out = self.fc_layers[i](out)
             out = self.nonlinearity(out)
         out = self.fc_layers[-1](out)
         out = out * self.out_scale + self.out_shift
         return out
+
+    def hook_fn(self, layer_name):
+        def hook(module, input, output):
+            self.activations[layer_name] = output.detach().cpu().numpy()
+            self.writer.add_histogram(f'Activations/{layer_name}', output)
+        return hook
+
+    def register_hooks(self):
+        for i, layer in enumerate(self.fc_layers):
+            layer.register_forward_hook(self.hook_fn(f'fc_layer_{i}'))
+
+    def close_writer(self):
+        self.writer.close()
+
+    def to(self, device):
+        self.device = device
+        # change the transforms to the appropriate device
+        self.in_shift = self.in_shift.to(device)
+        self.in_scale = self.in_scale.to(device)
+        self.out_shift = self.out_shift.to(device)
+        self.out_scale = self.out_scale.to(device)
+        # move all other trainable parameters to device
+        super().to(device) 
+
+
+class FCNetworkWithBatchNorm(nn.Module):
+    def __init__(self, obs_dim, act_dim,
+                 hidden_sizes=(64,64),
+                 nonlinearity='relu',   # either 'tanh' or 'relu'
+                 dropout=0,           # probability to dropout activations (0 means no dropout)
+                 log_dir='runs/activations_with_batchnorm',
+                 *args, **kwargs):
+        super(FCNetworkWithBatchNorm, self).__init__()
+
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        assert type(hidden_sizes) == tuple
+        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
+        self.device = 'cpu'
+
+        # hidden layers
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
+                         for i in range(len(self.layer_sizes) - 1)])
+        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
+        self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
+        self.dropout = nn.Dropout(dropout)
+        
+        self.activations = {}
+        self.writer = SummaryWriter(log_dir)
+
+    def forward(self, x):
+        out = x.to(self.device)
+        out = self.input_batchnorm(out)
+        for i in range(len(self.fc_layers) - 1):
+            out = self.fc_layers[i](out)
+            out = self.dropout(out)
+            out = self.nonlinearity(out)
+        out = self.fc_layers[-1](out)
+        return out
+
+    def hook_fn(self, layer_name):
+        def hook(module, input, output):
+            self.activations[layer_name] = output.detach().cpu().numpy()
+            self.writer.add_histogram(f'Activations/{layer_name}', output)
+        return hook
+
+    def register_hooks(self):
+        for i, layer in enumerate(self.fc_layers):
+            layer.register_forward_hook(self.hook_fn(f'fc_layer_{i}'))
+
+    def close_writer(self):
+        self.writer.close()
+
+    def to(self, device):
+        self.device = device
+        super().to(device)
+
+    def set_transformations(self, *args, **kwargs):
+        pass
+

From e791e181aca90e7d2023a5e052e509f09cf21a0c Mon Sep 17 00:00:00 2001
From: LyesBesylex <elyaslarfi4am8@gmail.com>
Date: Mon, 10 Jun 2024 12:43:59 -0600
Subject: [PATCH 13/13] Added neural activation return

---
 mjrl/algos/trpo.py                            |   1 -
 mjrl/policies/gaussian_mlp.py                 |  84 ++++++--------
 mjrl/utils/fc_network.py                      | 105 ++++++++++++------
 ...ents.out.tfevents.1718004269.Elyas.31368.0 | Bin 0 -> 88 bytes
 ...ents.out.tfevents.1718004290.Elyas.31024.0 | Bin 0 -> 88 bytes
 5 files changed, 105 insertions(+), 85 deletions(-)
 create mode 100644 mjrl/utils/runs/activations/events.out.tfevents.1718004269.Elyas.31368.0
 create mode 100644 mjrl/utils/runs/activations/events.out.tfevents.1718004290.Elyas.31024.0
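
The __getstate__/__setstate__ pair added below exists because a SummaryWriter
holds open file handles and cannot be pickled along with the network; a
generic sketch of the idiom (the Recorder class is illustrative only):

    import pickle

    class Recorder:
        def __init__(self, path='/tmp/rec.log'):
            self.path = path
            self.writer = open(path, 'a')       # stands in for SummaryWriter

        def __getstate__(self):
            state = self.__dict__.copy()
            state.pop('writer', None)           # drop the unpicklable handle
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.writer = open(self.path, 'a')  # recreate it on unpickle

    clone = pickle.loads(pickle.dumps(Recorder()))
    print(type(clone.writer).__name__)          # -> TextIOWrapper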

diff --git a/mjrl/algos/trpo.py b/mjrl/algos/trpo.py
index 7b46e88..1574750 100644
--- a/mjrl/algos/trpo.py
+++ b/mjrl/algos/trpo.py
@@ -12,7 +12,6 @@
 
 # samplers
 import mjrl.samplers.core as trajectory_sampler
-import mjrl.samplers.batch_sampler as batch_sampler
 
 # utility functions
 import mjrl.utils.process_samples as process_samples
diff --git a/mjrl/policies/gaussian_mlp.py b/mjrl/policies/gaussian_mlp.py
index a165bb7..51fee8a 100644
--- a/mjrl/policies/gaussian_mlp.py
+++ b/mjrl/policies/gaussian_mlp.py
@@ -3,60 +3,42 @@
 import torch
 from torch.autograd import Variable
 
-
 class MLP:
     def __init__(self, env_spec,
                  hidden_sizes=(64,64),
                  min_log_std=-3,
                  init_log_std=0,
                  seed=None):
-        """
-        :param env_spec: specifications of the env (see utils/gym_env.py)
-        :param hidden_sizes: network hidden layer sizes (currently 2 layers only)
-        :param min_log_std: log_std is clamped at this value and can't go below
-        :param init_log_std: initial log standard deviation
-        :param seed: random seed
-        """
         self.n = env_spec.observation_dim  # number of states
         self.m = env_spec.action_dim  # number of actions
         self.min_log_std = min_log_std
 
-        # Set seed
-        # ------------------------
         if seed is not None:
             torch.manual_seed(seed)
             np.random.seed(seed)
 
-        # Policy network
-        # ------------------------
         self.model = FCNetwork(self.n, self.m, hidden_sizes)
-        # make weights small
         for param in list(self.model.parameters())[-2:]:  # only last layer
            param.data = 1e-2 * param.data
         self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
         self.trainable_params = list(self.model.parameters()) + [self.log_std]
 
-        # Old Policy network
-        # ------------------------
         self.old_model = FCNetwork(self.n, self.m, hidden_sizes)
         self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
         self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
         for idx, param in enumerate(self.old_params):
             param.data = self.trainable_params[idx].data.clone()
 
-        # Easy access variables
-        # -------------------------
         self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
         self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
         self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
         self.d = np.sum(self.param_sizes)  # total number of params
 
-        # Placeholders
-        # ------------------------
         self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
 
-    # Utility functions
-    # ============================================
+    def show_activations(self):
+        return self.model.activations
+
     def get_param_values(self):
         params = np.concatenate([p.contiguous().view(-1).data.numpy()
                                  for p in self.trainable_params])
@@ -70,10 +52,8 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
                 vals = vals.reshape(self.param_shapes[idx])
                 param.data = torch.from_numpy(vals).float()
                 current_idx += self.param_sizes[idx]
-            # clip std at minimum value
             self.trainable_params[-1].data = \
                 torch.clamp(self.trainable_params[-1], self.min_log_std).data
-            # update log_std_val for sampling
             self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
         if set_old:
             current_idx = 0
@@ -82,12 +62,9 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
                 vals = vals.reshape(self.param_shapes[idx])
                 param.data = torch.from_numpy(vals).float()
                 current_idx += self.param_sizes[idx]
-            # clip std at minimum value
             self.old_params[-1].data = \
                 torch.clamp(self.old_params[-1], self.min_log_std).data
 
-    # Main functions
-    # ============================================
     def get_action(self, observation):
         o = np.float32(observation.reshape(1, -1))
         self.obs_var.data = torch.from_numpy(o)
@@ -144,6 +121,12 @@ def mean_kl(self, new_dist_info, old_dist_info):
         sample_kl = torch.sum(Nr / Dr + new_log_std - old_log_std, dim=1)
         return torch.mean(sample_kl)
 
+    # Ensure to close the writer when done
+    def close_writer(self):
+        self.model.close_writer()
+        self.old_model.close_writer()
+
+
 
 class BatchNormMLP(MLP):
     def __init__(self, env_spec,
@@ -153,53 +136,52 @@ def __init__(self, env_spec,
                  seed=None,
                  nonlinearity='relu',
                  dropout=0,
-                 *args, **kwargs,
-                 ):
-        """
-        :param env_spec: specifications of the env (see utils/gym_env.py)
-        :param hidden_sizes: network hidden layer sizes (currently 2 layers only)
-        :param min_log_std: log_std is clamped at this value and can't go below
-        :param init_log_std: initial log standard deviation
-        :param seed: random seed
-        """
-        # super(BatchNormMLP, self).__init__()
-
+                 log_dir='runs/activations_with_batchnorm',
+                 *args, **kwargs):
         self.n = env_spec.observation_dim  # number of states
         self.m = env_spec.action_dim  # number of actions
         self.min_log_std = min_log_std
 
-        # Set seed
-        # ------------------------
         if seed is not None:
             torch.manual_seed(seed)
             np.random.seed(seed)
 
-        # Policy network
-        # ------------------------
-        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout)
-        # make weights small
+        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)
+
         for param in list(self.model.parameters())[-2:]:  # only last layer
            param.data = 1e-2 * param.data
         self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
         self.trainable_params = list(self.model.parameters()) + [self.log_std]
         self.model.eval()
 
-        # Old Policy network
-        # ------------------------
-        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout)
+        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)
         self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
         self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
         for idx, param in enumerate(self.old_params):
             param.data = self.trainable_params[idx].data.clone()
         self.old_model.eval()
 
-        # Easy access variables
-        # -------------------------
         self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
         self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
         self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
         self.d = np.sum(self.param_sizes)  # total number of params
 
-        # Placeholders
-        # ------------------------
-        self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
\ No newline at end of file
+        self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
+
+        # Register hooks to log activations
+        self.model.register_hooks()
+        self.old_model.register_hooks()
+        # call close_writer() once training is complete to flush activation logs
+
+    def get_action(self, observation):
+        o = np.float32(observation.reshape(1, -1))
+        self.obs_var.data = torch.from_numpy(o)
+        mean = self.model(self.obs_var).data.numpy().ravel()
+        noise = np.exp(self.log_std_val) * np.random.randn(self.m)
+        action = mean + noise
+        return [action, {'mean': mean, 'log_std': self.log_std_val, 'evaluation': mean}]
+
+    # Ensure to close the writer when done
+    def close_writer(self):
+        self.model.close_writer()
+        self.old_model.close_writer()
diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index 6b5837b..9a3c07c 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -5,8 +5,8 @@
 
 class FCNetwork(nn.Module):
     def __init__(self, obs_dim, act_dim,
-                 hidden_sizes=(64,64),
-                 nonlinearity='tanh',   # either 'tanh' or 'relu'
+                 hidden_sizes=(64, 64),
+                 nonlinearity='sigmoid',  # 'tanh', 'relu', 'sigmoid', or 'softplus'
                  in_shift=None,
                  in_scale=None,
                  out_shift=None,
@@ -20,10 +20,19 @@ def __init__(self, obs_dim, act_dim,
         self.layer_sizes = (obs_dim,) + hidden_sizes + (act_dim,)
         self.set_transformations(in_shift, in_scale, out_shift, out_scale)
 
-        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1])
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i + 1])
                                         for i in range(len(self.layer_sizes) - 1)])
-        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
-        
+        if nonlinearity == 'relu':
+            self.nonlinearity = torch.relu 
+        elif nonlinearity == 'tanh':
+            self.nonlinearity = torch.tanh
+        elif nonlinearity == 'sigmoid':
+            self.nonlinearity = torch.sigmoid
+        elif nonlinearity == 'softplus':
+            self.nonlinearity = nn.functional.softplus
+        else:
+            raise ValueError("Nonlinearity must be 'tanh', 'relu', 'sigmoid', or 'softplus'")
+
         self.activations = {}
         self.writer = SummaryWriter(log_dir)
 
@@ -38,27 +47,24 @@ def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_
         self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim)
 
     def forward(self, x):
-        if x.is_cuda:
-            out = x.to('cpu')
-        else:
-            out = x
+        try:
+            out = x.to(self.device)
+        except AttributeError:
+            if not hasattr(self, 'device'):
+                self.device = 'cpu'
+                out = x.to(self.device)
+            else:
+                raise
         out = (out - self.in_shift) / (self.in_scale + 1e-8)
         for i in range(len(self.fc_layers) - 1):
             out = self.fc_layers[i](out)
             out = self.nonlinearity(out)
+            self.activations["layer_"+ str(i)] = out.detach().cpu().numpy() 
         out = self.fc_layers[-1](out)
         out = out * self.out_scale + self.out_shift
-        return out
-
-    def hook_fn(self, layer_name):
-        def hook(module, input, output):
-            self.activations[layer_name] = output.detach().cpu().numpy()
-            self.writer.add_histogram(f'Activations/{layer_name}', output)
-        return hook
+        self.activations["output_layer"] = out.detach().cpu().numpy() 
+        return out #Outputs final layer
 
-    def register_hooks(self):
-        for i, layer in enumerate(self.fc_layers):
-            layer.register_forward_hook(self.hook_fn(f'fc_layer_{i}'))
 
     def close_writer(self):
         self.writer.close()
@@ -71,14 +77,25 @@ def to(self, device):
         self.out_shift = self.out_shift.to(device)
         self.out_scale = self.out_scale.to(device)
         # move all other trainable parameters to device
-        super().to(device) 
+        super().to(device)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove the SummaryWriter from the state
+        if 'writer' in state:
+            del state['writer']
+        return state
 
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Recreate the SummaryWriter (won't be in pickled state)
+        self.writer = SummaryWriter()
 
 class FCNetworkWithBatchNorm(nn.Module):
     def __init__(self, obs_dim, act_dim,
-                 hidden_sizes=(64,64),
-                 nonlinearity='relu',   # either 'tanh' or 'relu'
-                 dropout=0,           # probability to dropout activations (0 means no dropout)
+                 hidden_sizes=(64, 64),
+                 nonlinearity='relu',  # either 'tanh' or 'relu'
+                 dropout=0,  # probability to dropout activations (0 means no dropout)
                  log_dir='runs/activations_with_batchnorm',
                  *args, **kwargs):
         super(FCNetworkWithBatchNorm, self).__init__()
@@ -86,16 +103,16 @@ def __init__(self, obs_dim, act_dim,
         self.obs_dim = obs_dim
         self.act_dim = act_dim
         assert type(hidden_sizes) == tuple
-        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
+        self.layer_sizes = (obs_dim,) + hidden_sizes + (act_dim,)
         self.device = 'cpu'
 
         # hidden layers
-        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
-                         for i in range(len(self.layer_sizes) - 1)])
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i + 1])
+                                        for i in range(len(self.layer_sizes) - 1)])
         self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
         self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
         self.dropout = nn.Dropout(dropout)
-        
+
         self.activations = {}
         self.writer = SummaryWriter(log_dir)
 
@@ -107,17 +124,28 @@ def forward(self, x):
             out = self.dropout(out)
             out = self.nonlinearity(out)
         out = self.fc_layers[-1](out)
+        # print(out)  # debug
         return out
 
-    def hook_fn(self, layer_name):
-        def hook(module, input, output):
-            self.activations[layer_name] = output.detach().cpu().numpy()
-            self.writer.add_histogram(f'Activations/{layer_name}', output)
-        return hook
+    def hook_fn(self, module, input, output, layer_name, neuron_idx=None):
+        self.activations[layer_name] = output.detach().cpu().numpy()
+        self.writer.add_histogram(f'Activations/{layer_name}', output)
+
+    def hook_fn_layer(self, module, input, output):
+        layer_name = getattr(module, '_layer_name', module.__class__.__name__)
+        self.hook_fn(module, input, output, layer_name)
+
+    def hook_fn_neuron(self, module, input, output):
+        layer_name = module.__class__.__name__
+        neuron_idx = module.neuron_idx
+        self.hook_fn(module, input, output, layer_name, neuron_idx)
 
     def register_hooks(self):
         for i, layer in enumerate(self.fc_layers):
-            layer.register_forward_hook(self.hook_fn(f'fc_layer_{i}'))
+            layer_name = f'fc_layer_{i}'
+            layer._layer_name = layer_name  # tag the instance; renaming __class__ would rename every nn.Linear
+            layer.register_forward_hook(self.hook_fn_layer)
+
 
     def close_writer(self):
         self.writer.close()
@@ -129,3 +157,14 @@ def to(self, device):
     def set_transformations(self, *args, **kwargs):
         pass
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove the SummaryWriter from the state
+        if 'writer' in state:
+            del state['writer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Recreate the SummaryWriter (won't be in pickled state)
+        self.writer = SummaryWriter()
diff --git a/mjrl/utils/runs/activations/events.out.tfevents.1718004269.Elyas.31368.0 b/mjrl/utils/runs/activations/events.out.tfevents.1718004269.Elyas.31368.0
new file mode 100644
index 0000000000000000000000000000000000000000..2633da828ac9a732977abf191db1cbd2a927d41c
GIT binary patch
literal 88
zcmeZZfPjCKJmzv*rIdEBnt9VviZ`h!F*8rkwJbHS#L6g0k4vW{HLp0oC@DX&C`GTh
hG&eV~s8X-ID6=HBNG}znDn2bUCp8`-a{s>bOaL9-Al3i?

literal 0
HcmV?d00001

diff --git a/mjrl/utils/runs/activations/events.out.tfevents.1718004290.Elyas.31024.0 b/mjrl/utils/runs/activations/events.out.tfevents.1718004290.Elyas.31024.0
new file mode 100644
index 0000000000000000000000000000000000000000..a5976413fd153bd1a1f3aa1f4241cb79acee16a6
GIT binary patch
literal 88
zcmeZZfPjCKJmzvfw(Xg)YUWKxDc+=_#LPTB*Rs^S5-X!1JuaP+)V$*SqNM!9q7=R2
h(%js{qDsB;qRf)iBE3|Qs`#|boYZ)TNKWa$O8`o?A&US2

literal 0
HcmV?d00001