diff --git a/examples/example_configs/hopper_npg.txt b/examples/example_configs/hopper_npg.txt
index bd98381..bf16679 100644
--- a/examples/example_configs/hopper_npg.txt
+++ b/examples/example_configs/hopper_npg.txt
@@ -29,5 +29,11 @@
 
 'alg_hyper_params'  :   dict(),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
 
diff --git a/examples/example_configs/swimmer_npg.txt b/examples/example_configs/swimmer_npg.txt
index f8af3a8..8c09008 100644
--- a/examples/example_configs/swimmer_npg.txt
+++ b/examples/example_configs/swimmer_npg.txt
@@ -29,4 +29,10 @@
 
 'alg_hyper_params'  :   dict(),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
\ No newline at end of file
diff --git a/examples/example_configs/swimmer_ppo.txt b/examples/example_configs/swimmer_ppo.txt
index e20d592..6c561e0 100644
--- a/examples/example_configs/swimmer_ppo.txt
+++ b/examples/example_configs/swimmer_ppo.txt
@@ -7,7 +7,7 @@
 'seed'              :   123,
 'sample_mode'       :   'trajectories',
 'rl_num_traj'       :   10,
-'rl_num_iter'       :   50,
+'rl_num_iter'       :   10,
 'num_cpu'           :   2,
 'save_freq'         :   25,
 'eval_rollouts'     :   None,
@@ -29,4 +29,10 @@
 
 'alg_hyper_params'  :   dict(clip_coef=0.2, epochs=10, mb_size=64, learn_rate=5e-4),
 
+'wandb_params':  {
+    'use_wandb'     : True,
+    'wandb_user'    : 'vikashplus',
+    'wandb_project' : 'mjrl_demo',
+    'wandb_exp'     : 'demo_exp',
+    }
 }
\ No newline at end of file
diff --git a/examples/policy_opt_job_script.py b/examples/policy_opt_job_script.py
index 0ee68df..e0ae249 100644
--- a/examples/policy_opt_job_script.py
+++ b/examples/policy_opt_job_script.py
@@ -13,6 +13,7 @@
 from mjrl.algos.batch_reinforce import BatchREINFORCE
 from mjrl.algos.ppo_clip import PPO
 from mjrl.utils.train_agent import train_agent
+from mjrl.utils.logger import DataLog
 import os
 import json
 import gym
@@ -82,6 +83,16 @@
     # or defaults in the PPO algorithm will be used
     agent = PPO(e, policy, baseline, save_logs=True, **job_data['alg_hyper_params'])
 
+
+# Update the logger if WandB is enabled in the config
+if 'wandb_params' in job_data and job_data['wandb_params']['use_wandb']:
+    if 'wandb_logdir' in job_data['wandb_params']:
+        job_data['wandb_params']['wandb_logdir'] = os.path.join(JOB_DIR, job_data['wandb_params']['wandb_logdir'])
+    else:
+        job_data['wandb_params']['wandb_logdir'] = JOB_DIR
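+    # replace the default logger with one that mirrors logged scalars to WandB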
+    agent.logger = DataLog(**job_data['wandb_params'], wandb_config=job_data)
+
+
 print("========================================")
 print("Starting policy learning")
 print("========================================")
diff --git a/mjrl/__init__.py b/mjrl/__init__.py
index 008b133..00e188e 100644
--- a/mjrl/__init__.py
+++ b/mjrl/__init__.py
@@ -1 +1,2 @@
-import mjrl.envs
\ No newline at end of file
+# Users should explicitly import these envs if needed; they depend on mujoco_py, which not all setups have.
+# import mjrl.envs
\ No newline at end of file
diff --git a/mjrl/algos/behavior_cloning.py b/mjrl/algos/behavior_cloning.py
index 6aac09e..bc1d512 100644
--- a/mjrl/algos/behavior_cloning.py
+++ b/mjrl/algos/behavior_cloning.py
@@ -118,6 +118,7 @@ def fit(self, data, suppress_fit_tqdm=False, **kwargs):
             self.logger.log_kv('loss_before', loss_val)
 
         # train loop
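+        # switch the policy network to train mode so dropout/batch-norm layers update while fitting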
+        self.policy.model.train()
         for ep in config_tqdm(range(self.epochs), suppress_fit_tqdm):
             for mb in range(int(num_samples / self.mb_size)):
                 rand_idx = np.random.choice(num_samples, size=self.mb_size)
@@ -125,6 +126,7 @@ def fit(self, data, suppress_fit_tqdm=False, **kwargs):
                 loss = self.loss(data, idx=rand_idx)
                 loss.backward()
                 self.optimizer.step()
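+        # back to eval mode so batch-norm/dropout behave deterministically when sampling actions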
+        self.policy.model.eval()
         params_after_opt = self.policy.get_param_values()
         self.policy.set_param_values(params_after_opt, set_new=True, set_old=True)
 
diff --git a/mjrl/algos/trpo.py b/mjrl/algos/trpo.py
index 7b46e88..1574750 100644
--- a/mjrl/algos/trpo.py
+++ b/mjrl/algos/trpo.py
@@ -12,7 +12,6 @@
 
 # samplers
 import mjrl.samplers.core as trajectory_sampler
-import mjrl.samplers.batch_sampler as batch_sampler
 
 # utility functions
 import mjrl.utils.process_samples as process_samples
diff --git a/mjrl/policies/gaussian_mlp.py b/mjrl/policies/gaussian_mlp.py
index ae97145..51fee8a 100644
--- a/mjrl/policies/gaussian_mlp.py
+++ b/mjrl/policies/gaussian_mlp.py
@@ -1,62 +1,44 @@
 import numpy as np
-from mjrl.utils.fc_network import FCNetwork
+from mjrl.utils.fc_network import FCNetwork, FCNetworkWithBatchNorm
 import torch
 from torch.autograd import Variable
 
-
 class MLP:
     def __init__(self, env_spec,
                  hidden_sizes=(64,64),
                  min_log_std=-3,
                  init_log_std=0,
                  seed=None):
-        """
-        :param env_spec: specifications of the env (see utils/gym_env.py)
-        :param hidden_sizes: network hidden layer sizes (currently 2 layers only)
-        :param min_log_std: log_std is clamped at this value and can't go below
-        :param init_log_std: initial log standard deviation
-        :param seed: random seed
-        """
         self.n = env_spec.observation_dim  # number of states
         self.m = env_spec.action_dim  # number of actions
         self.min_log_std = min_log_std
 
-        # Set seed
-        # ------------------------
         if seed is not None:
             torch.manual_seed(seed)
             np.random.seed(seed)
 
-        # Policy network
-        # ------------------------
         self.model = FCNetwork(self.n, self.m, hidden_sizes)
-        # make weights small
         for param in list(self.model.parameters())[-2:]:  # only last layer
            param.data = 1e-2 * param.data
         self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
         self.trainable_params = list(self.model.parameters()) + [self.log_std]
 
-        # Old Policy network
-        # ------------------------
         self.old_model = FCNetwork(self.n, self.m, hidden_sizes)
         self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
         self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
         for idx, param in enumerate(self.old_params):
             param.data = self.trainable_params[idx].data.clone()
 
-        # Easy access variables
-        # -------------------------
         self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
         self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
         self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
         self.d = np.sum(self.param_sizes)  # total number of params
 
-        # Placeholders
-        # ------------------------
         self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
 
-    # Utility functions
-    # ============================================
+    def show_activations(self):
+        return self.model.activations
+
     def get_param_values(self):
         params = np.concatenate([p.contiguous().view(-1).data.numpy()
                                  for p in self.trainable_params])
@@ -70,10 +52,8 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
                 vals = vals.reshape(self.param_shapes[idx])
                 param.data = torch.from_numpy(vals).float()
                 current_idx += self.param_sizes[idx]
-            # clip std at minimum value
             self.trainable_params[-1].data = \
                 torch.clamp(self.trainable_params[-1], self.min_log_std).data
-            # update log_std_val for sampling
             self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
         if set_old:
             current_idx = 0
@@ -82,12 +62,9 @@ def set_param_values(self, new_params, set_new=True, set_old=True):
                 vals = vals.reshape(self.param_shapes[idx])
                 param.data = torch.from_numpy(vals).float()
                 current_idx += self.param_sizes[idx]
-            # clip std at minimum value
             self.old_params[-1].data = \
                 torch.clamp(self.old_params[-1], self.min_log_std).data
 
-    # Main functions
-    # ============================================
     def get_action(self, observation):
         o = np.float32(observation.reshape(1, -1))
         self.obs_var.data = torch.from_numpy(o)
@@ -143,3 +120,68 @@ def mean_kl(self, new_dist_info, old_dist_info):
         Dr = 2 * new_std ** 2 + 1e-8
         sample_kl = torch.sum(Nr / Dr + new_log_std - old_log_std, dim=1)
         return torch.mean(sample_kl)
+
+    # Close the TensorBoard writers when done with the policy
+    def close_writer(self):
+        self.model.close_writer()
+        self.old_model.close_writer()
+
+
+
+class BatchNormMLP(MLP):
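+    # Variant of MLP whose mean network is an FCNetworkWithBatchNorm (input batch-norm
+    # plus optional dropout); the networks are kept in eval() mode except while being fit.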
+    def __init__(self, env_spec,
+                 hidden_sizes=(64,64),
+                 min_log_std=-3,
+                 init_log_std=0,
+                 seed=None,
+                 nonlinearity='relu',
+                 dropout=0,
+                 log_dir='runs/activations_with_batchnorm',
+                 *args, **kwargs):
+        self.n = env_spec.observation_dim  # number of states
+        self.m = env_spec.action_dim  # number of actions
+        self.min_log_std = min_log_std
+
+        if seed is not None:
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
+        self.model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)
+
+        for param in list(self.model.parameters())[-2:]:  # only last layer
+           param.data = 1e-2 * param.data
+        self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True)
+        self.trainable_params = list(self.model.parameters()) + [self.log_std]
+        self.model.eval()
+
+        self.old_model = FCNetworkWithBatchNorm(self.n, self.m, hidden_sizes, nonlinearity, dropout, log_dir=log_dir)
+        self.old_log_std = Variable(torch.ones(self.m) * init_log_std)
+        self.old_params = list(self.old_model.parameters()) + [self.old_log_std]
+        for idx, param in enumerate(self.old_params):
+            param.data = self.trainable_params[idx].data.clone()
+        self.old_model.eval()
+
+        self.log_std_val = np.float64(self.log_std.data.numpy().ravel())
+        self.param_shapes = [p.data.numpy().shape for p in self.trainable_params]
+        self.param_sizes = [p.data.numpy().size for p in self.trainable_params]
+        self.d = np.sum(self.param_sizes)  # total number of params
+
+        self.obs_var = Variable(torch.randn(self.n), requires_grad=False)
+
+        # Register hooks to log activations
+        self.model.register_hooks()
+        self.old_model.register_hooks()
+        self.close_writer()
+
+    def get_action(self, observation):
+        o = np.float32(observation.reshape(1, -1))
+        self.obs_var.data = torch.from_numpy(o)
+        mean = self.model(self.obs_var).data.numpy().ravel()
+        noise = np.exp(self.log_std_val) * np.random.randn(self.m)
+        action = mean + noise
+        return [action, {'mean': mean, 'log_std': self.log_std_val, 'evaluation': mean}]
+
+    # Close the TensorBoard writers when done with the policy
+    def close_writer(self):
+        self.model.close_writer()
+        self.old_model.close_writer()
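+
+# Usage sketch (illustrative; the env name and hyperparameters below are assumptions):
+#   e = GymEnv('Hopper-v3')
+#   policy = BatchNormMLP(e.spec, hidden_sizes=(64, 64), nonlinearity='relu', dropout=0.1)
+#   action, info = policy.get_action(e.reset())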
diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py
index be4a988..8c267c5 100644
--- a/mjrl/samplers/core.py
+++ b/mjrl/samplers/core.py
@@ -6,6 +6,7 @@
 import multiprocessing as mp
 import time as timer
 logging.disable(logging.CRITICAL)
+import gc
 
 
 # Single core rollout to sample trajectories
@@ -93,6 +94,7 @@ def do_rollout(
         paths.append(path)
 
     del(env)
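+    # explicitly collect garbage so the env's resources are released promptly in the worker process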
+    gc.collect()
     return paths
 
 
@@ -134,7 +136,7 @@ def sample_paths(
         start_time = timer.time()
         print("####### Gathering Samples #######")
 
-    results = _try_multiprocess(do_rollout, input_dict_list,
+    results = _try_multiprocess_cf(do_rollout, input_dict_list,
                                 num_cpu, max_process_time, max_timeouts)
     paths = []
     # result is a paths type and results is list of paths
@@ -186,7 +188,7 @@ def sample_data_batch(
     return paths
 
 
-def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
+def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
     
     # Base case
     if max_timeouts == 0:
@@ -202,9 +204,30 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_time
         pool.close()
         pool.terminate()
         pool.join()
-        return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
+        return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
 
     pool.close()
     pool.terminate()
     pool.join()  
     return results
+
+
+def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
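+    # concurrent.futures variant of _try_multiprocess_mp; the executor context manager
+    # takes care of worker cleanup, and None is returned if no results were collected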
+    import concurrent.futures
+    results = None
+    if max_timeouts != 0:
+        with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor:
+            submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list]
+            try:
+                results = [f.result(timeout=max_process_time) for f in submit_futures]
+            except concurrent.futures.TimeoutError as e:
+                print(str(e))
+                print("Timeout Error raised...")
+            except concurrent.futures.CancelledError as e:
+                print(str(e))
+                print("Future Cancelled Error raised...")
+            except Exception as e:
+                print(str(e))
+                print("Error raised...")
+                raise e
+    return results
\ No newline at end of file
diff --git a/mjrl/utils/fc_network.py b/mjrl/utils/fc_network.py
index ea3ad72..9a3c07c 100644
--- a/mjrl/utils/fc_network.py
+++ b/mjrl/utils/fc_network.py
@@ -1,52 +1,170 @@
-import numpy as np
 import torch
 import torch.nn as nn
-
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 
 class FCNetwork(nn.Module):
     def __init__(self, obs_dim, act_dim,
-                 hidden_sizes=(64,64),
-                 nonlinearity='tanh',   # either 'tanh' or 'relu'
-                 in_shift = None,
-                 in_scale = None,
-                 out_shift = None,
-                 out_scale = None):
+                 hidden_sizes=(64, 64),
+                 nonlinearity='sigmoid',  # 'tanh', 'relu', 'sigmoid', or 'softplus'
+                 in_shift=None,
+                 in_scale=None,
+                 out_shift=None,
+                 out_scale=None,
+                 log_dir='runs/activations'):
         super(FCNetwork, self).__init__()
 
         self.obs_dim = obs_dim
         self.act_dim = act_dim
         assert type(hidden_sizes) == tuple
-        self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, )
+        self.layer_sizes = (obs_dim,) + hidden_sizes + (act_dim,)
         self.set_transformations(in_shift, in_scale, out_shift, out_scale)
 
-        # hidden layers
-        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \
-                         for i in range(len(self.layer_sizes) -1)])
-        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i + 1])
+                                        for i in range(len(self.layer_sizes) - 1)])
+        if nonlinearity == 'relu':
+            self.nonlinearity = torch.relu 
+        elif nonlinearity == 'tanh':
+            self.nonlinearity = torch.tanh
+        elif nonlinearity == 'sigmoid':
+            self.nonlinearity = torch.sigmoid
+        elif nonlinearity == 'softplus':
+            self.nonlinearity = nn.functional.softplus
+        else:
+            raise ValueError("Nonlinearity must be 'tanh', 'relu', 'sigmoid', or 'softplus'")
+
+        self.activations = {}
+        self.writer = SummaryWriter(log_dir)
 
     def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_scale=None):
-        # store native scales that can be used for resets
         self.transformations = dict(in_shift=in_shift,
-                           in_scale=in_scale,
-                           out_shift=out_shift,
-                           out_scale=out_scale
-                          )
-        self.in_shift  = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim)
-        self.in_scale  = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim)
+                                    in_scale=in_scale,
+                                    out_shift=out_shift,
+                                    out_scale=out_scale)
+        self.in_shift = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim)
+        self.in_scale = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim)
         self.out_shift = torch.from_numpy(np.float32(out_shift)) if out_shift is not None else torch.zeros(self.act_dim)
         self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim)
 
     def forward(self, x):
-        # TODO(Aravind): Remove clamping to CPU
-        # This is a temp change that should be fixed shortly
-        if x.is_cuda:
-            out = x.to('cpu')
-        else:
-            out = x
-        out = (out - self.in_shift)/(self.in_scale + 1e-8)
-        for i in range(len(self.fc_layers)-1):
+        # default to CPU if the network has not been moved to a device yet
+        if not hasattr(self, 'device'):
+            self.device = 'cpu'
+        out = x.to(self.device)
+        out = (out - self.in_shift) / (self.in_scale + 1e-8)
+        for i in range(len(self.fc_layers) - 1):
             out = self.fc_layers[i](out)
             out = self.nonlinearity(out)
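+            # cache this layer's activations (as numpy) so they can be inspected later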
+            self.activations["layer_"+ str(i)] = out.detach().cpu().numpy() 
         out = self.fc_layers[-1](out)
         out = out * self.out_scale + self.out_shift
+        self.activations["output_layer"] = out.detach().cpu().numpy() 
+        return out #Outputs final layer
+
+
+    def close_writer(self):
+        self.writer.close()
+
+    def to(self, device):
+        self.device = device
+        # change the transforms to the appropriate device
+        self.in_shift = self.in_shift.to(device)
+        self.in_scale = self.in_scale.to(device)
+        self.out_shift = self.out_shift.to(device)
+        self.out_scale = self.out_scale.to(device)
+        # move all other trainable parameters to device
+        super().to(device)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove the SummaryWriter from the state
+        if 'writer' in state:
+            del state['writer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Recreate the SummaryWriter (won't be in pickled state)
+        self.writer = SummaryWriter()
+
+class FCNetworkWithBatchNorm(nn.Module):
+    def __init__(self, obs_dim, act_dim,
+                 hidden_sizes=(64, 64),
+                 nonlinearity='relu',  # either 'tanh' or 'relu'
+                 dropout=0,  # probability to dropout activations (0 means no dropout)
+                 log_dir='runs/activations_with_batchnorm',
+                 *args, **kwargs):
+        super(FCNetworkWithBatchNorm, self).__init__()
+
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        assert type(hidden_sizes) == tuple
+        self.layer_sizes = (obs_dim,) + hidden_sizes + (act_dim,)
+        self.device = 'cpu'
+
+        # hidden layers
+        self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i + 1])
+                                        for i in range(len(self.layer_sizes) - 1)])
+        self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh
+        self.input_batchnorm = nn.BatchNorm1d(num_features=obs_dim)
+        self.dropout = nn.Dropout(dropout)
+
+        self.activations = {}
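+        # TensorBoard writer used by the forward hooks to record activation histograms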
+        self.writer = SummaryWriter(log_dir)
+
+    def forward(self, x):
+        out = x.to(self.device)
+        out = self.input_batchnorm(out)
+        for i in range(len(self.fc_layers) - 1):
+            out = self.fc_layers[i](out)
+            out = self.dropout(out)
+            out = self.nonlinearity(out)
+        out = self.fc_layers[-1](out)
         return out
+
+    def hook_fn(self, module, input, output, layer_name, neuron_idx=None):
+        self.activations[layer_name] = output.detach().cpu().numpy()
+        self.writer.add_histogram(f'Activations/{layer_name}', output)
+
+    def hook_fn_layer(self, module, input, output):
+        # the layer name is attached to the module instance in register_hooks
+        layer_name = getattr(module, 'hook_name', module.__class__.__name__)
+        self.hook_fn(module, input, output, layer_name)
+
+    def hook_fn_neuron(self, module, input, output):
+        layer_name = module.__class__.__name__
+        neuron_idx = module.neuron_idx
+        self.hook_fn(module, input, output, layer_name, neuron_idx)
+
+    def register_hooks(self):
+        for i, layer in enumerate(self.fc_layers):
+            # tag the layer instance instead of mutating nn.Linear.__name__ globally
+            layer.hook_name = f'fc_layer_{i}'
+            layer.register_forward_hook(self.hook_fn_layer)
+
+
+    def close_writer(self):
+        self.writer.close()
+
+    def to(self, device):
+        self.device = device
+        super().to(device)
+
+    def set_transformations(self, *args, **kwargs):
+        pass
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove the SummaryWriter from the state
+        if 'writer' in state:
+            del state['writer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Recreate the SummaryWriter (won't be in pickled state)
+        self.writer = SummaryWriter()
diff --git a/mjrl/utils/gym_env.py b/mjrl/utils/gym_env.py
index eb45902..434b7e2 100644
--- a/mjrl/utils/gym_env.py
+++ b/mjrl/utils/gym_env.py
@@ -15,9 +15,9 @@ def __init__(self, obs_dim, act_dim, horizon):
 
 class GymEnv(object):
     def __init__(self, env, env_kwargs=None,
-                 obs_mask=None, act_repeat=1, 
+                 obs_mask=None, act_repeat=1,
                  *args, **kwargs):
-    
+
         # get the correct env behavior
         if type(env) == str:
             env = gym.make(env)
@@ -30,7 +30,7 @@ def __init__(self, env, env_kwargs=None,
             raise AttributeError
 
         self.env = env
-        self.env_id = env.spec.id
+        self.env_id = env.unwrapped.spec.id
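+        # go through env.unwrapped to reach the base env beneath wrappers such as gym's TimeLimit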
         self.act_repeat = act_repeat
 
         try:
@@ -42,14 +42,14 @@ def __init__(self, env, env_kwargs=None,
         self._horizon = self._horizon // self.act_repeat
 
         try:
-            self._action_dim = self.env.env.action_dim
-        except AttributeError:
             self._action_dim = self.env.action_space.shape[0]
+        except AttributeError:
+            self._action_dim = self.env.unwrapped.action_dim
 
         try:
-            self._observation_dim = self.env.env.obs_dim
-        except AttributeError:
             self._observation_dim = self.env.observation_space.shape[0]
+        except AttributeError:
+            self._observation_dim = self.env.unwrapped.obs_dim
 
         # Specs
         self.spec = EnvSpec(self._observation_dim, self._action_dim, self._horizon)
@@ -80,7 +80,7 @@ def horizon(self):
     def reset(self, seed=None):
         try:
             self.env._elapsed_steps = 0
-            return self.env.env.reset_model(seed=seed)
+            return self.env.unwrapped.reset_model(seed=seed)
         except:
             if seed is not None:
                 self.set_seed(seed)
@@ -92,7 +92,7 @@ def reset_model(self, seed=None):
 
     def step(self, action):
         action = action.clip(self.action_space.low, self.action_space.high)
-        if self.act_repeat == 1: 
+        if self.act_repeat == 1:
             obs, cum_reward, done, ifo = self.env.step(action)
         else:
             cum_reward = 0.0
@@ -104,8 +104,8 @@ def step(self, action):
 
     def render(self):
         try:
-            self.env.env.mujoco_render_frames = True
-            self.env.env.mj_render()
+            self.env.unwrapped.mujoco_render_frames = True
+            self.env.unwrapped.mj_render()
         except:
             self.env.render()
 
@@ -117,13 +117,13 @@ def set_seed(self, seed=123):
 
     def get_obs(self):
         try:
-            return self.obs_mask * self.env.env.get_obs()
+            return self.obs_mask * self.env.get_obs()
         except:
-            return self.obs_mask * self.env.env._get_obs()
+            return self.obs_mask * self.env._get_obs()
 
     def get_env_infos(self):
         try:
-            return self.env.env.get_env_infos()
+            return self.env.unwrapped.get_env_infos()
         except:
             return {}
 
@@ -133,19 +133,19 @@ def get_env_infos(self):
 
     def get_env_state(self):
         try:
-            return self.env.env.get_env_state()
+            return self.env.unwrapped.get_env_state()
         except:
             raise NotImplementedError
 
     def set_env_state(self, state_dict):
         try:
-            self.env.env.set_env_state(state_dict)
+            self.env.unwrapped.set_env_state(state_dict)
         except:
             raise NotImplementedError
 
     def real_env_step(self, bool_val):
         try:
-            self.env.env.real_step = bool_val
+            self.env.unwrapped.real_step = bool_val
         except:
             raise NotImplementedError
 
@@ -153,7 +153,7 @@ def real_env_step(self, bool_val):
 
     def visualize_policy(self, policy, horizon=1000, num_episodes=1, mode='exploration'):
         try:
-            self.env.env.visualize_policy(policy, horizon, num_episodes, mode)
+            self.env.unwrapped.visualize_policy(policy, horizon, num_episodes, mode)
         except:
             for ep in range(num_episodes):
                 o = self.reset()
diff --git a/mjrl/utils/logger.py b/mjrl/utils/logger.py
index 0a155ed..f96a074 100644
--- a/mjrl/utils/logger.py
+++ b/mjrl/utils/logger.py
@@ -7,15 +7,32 @@
 import os
 import csv
 
+# Defaults
+USERNAME = 'aravraj'
+WANDB_PROJECT = 'mjrl_test'
+
 class DataLog:
 
-    def __init__(self):
+    def __init__(self,
+                 use_wandb: bool = False,
+                 wandb_user: str = USERNAME,
+                 wandb_project: str = WANDB_PROJECT,
+                 wandb_exp: str = None,
+                 wandb_logdir: str = None,
+                 wandb_config: dict = dict()) -> None:
+        self.use_wandb = use_wandb
+        if use_wandb:
+            import wandb
+            self.run = wandb.init(project=wandb_project, entity=wandb_user, dir=wandb_logdir, config=wandb_config)
+            # Update the experiment name if explicitly specified
+            if wandb_exp is not None: wandb.run.name = wandb_exp
+
         self.log = {}
         self.max_len = 0
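+        # iteration counter used as the wandb step when mirroring logged values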
+        self.global_step = 0
 
     def log_kv(self, key, value):
         # logs the (key, value) pair
-
         # TODO: This implementation is error-prone:
         # it would be NOT aligned if some keys are missing during one iteration.
         if key not in self.log:
@@ -23,6 +40,8 @@ def log_kv(self, key, value):
         self.log[key].append(value)
         if len(self.log[key]) > self.max_len:
             self.max_len = self.max_len + 1
+        if self.use_wandb:
+            self.run.log({key: value}, step=self.global_step)
 
     def save_log(self, save_path):
         # TODO: Validate all lengths are the same.
@@ -56,6 +75,11 @@ def shrink_to(self, num_entries):
         assert min([len(series) for series in self.log.values()]) == \
             max([len(series) for series in self.log.values()])
 
+    def sync_log_with_wandb(self):
+        # Syncs the latest logged entries with wandb
+        latest_log = self.get_current_log()
+        self.run.log(latest_log, step=self.global_step)
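+
+    # Usage sketch (assumes wandb is installed and configured; names are illustrative):
+    #   logger = DataLog(use_wandb=True, wandb_user='my-user', wandb_project='mjrl_demo')
+    #   logger.log_kv('stoc_pol_mean', 1.23)
+    #   logger.sync_log_with_wandb()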
+
     def read_log(self, log_path):
         assert log_path.endswith('log.csv')
 
diff --git a/mjrl/utils/tensor_utils.py b/mjrl/utils/tensor_utils.py
index 8b0002a..fc7b90f 100644
--- a/mjrl/utils/tensor_utils.py
+++ b/mjrl/utils/tensor_utils.py
@@ -61,7 +61,7 @@ def high_res_normalize(probs):
 
 
 def stack_tensor_list(tensor_list):
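+    # dtype=object allows stacking trajectories of unequal length without numpy raising an error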
-    return np.array(tensor_list)
+    return np.array(tensor_list, dtype=object)
     # tensor_shape = np.array(tensor_list[0]).shape
     # if tensor_shape is tuple():
     #     return np.array(tensor_list)
diff --git a/mjrl/utils/train_agent.py b/mjrl/utils/train_agent.py
index 688b638..651a894 100644
--- a/mjrl/utils/train_agent.py
+++ b/mjrl/utils/train_agent.py
@@ -1,6 +1,3 @@
-import logging
-logging.disable(logging.CRITICAL)
-
 from tabulate import tabulate
 from mjrl.utils.make_train_plots import make_train_plots
 from mjrl.utils.gym_env import GymEnv
@@ -94,9 +91,12 @@ def train_agent(job_name, agent,
     if i_start:
         print("Resuming from an existing job folder ...")
 
+    env_samples = 0
     for i in range(i_start, niter):
         print("......................................................................................")
         print("ITERATION : %i " % i)
+        if agent.logger.use_wandb:
+            agent.logger.global_step += 1
 
         if train_curve[i-1] > best_perf:
             best_policy = copy.deepcopy(agent.policy)
@@ -106,6 +106,10 @@ def train_agent(job_name, agent,
         args = dict(N=N, sample_mode=sample_mode, gamma=gamma, gae_lambda=gae_lambda, num_cpu=num_cpu)
         stats = agent.train_step(**args)
         train_curve[i] = stats[0]
+        # log the total number of environment samples so far, for convenience
+        iter_samples = agent.logger.get_current_log()['num_samples']
+        env_samples += iter_samples
+        agent.logger.log_kv('env_samples', env_samples)
 
         if evaluation_rollouts is not None and evaluation_rollouts > 0:
             print("Performing evaluation rollouts ........")
diff --git a/setup/env.yml b/setup/env.yml
index 5bc4397..04e4159 100644
--- a/setup/env.yml
+++ b/setup/env.yml
@@ -7,7 +7,7 @@ dependencies:
 - pip
 - ipython
 - mkl-service
-- pytorch==1.4
+- pytorch==1.9.0
 - tabulate
 - termcolor
 - torchvision