Commit: MentorMix
Liphos committed Sep 26, 2022
1 parent 9012486 commit 8a9f0cc
Showing 2 changed files with 49 additions and 19 deletions.
64 changes: 47 additions & 17 deletions main.py
@@ -17,7 +17,7 @@
from torch.utils.tensorboard import SummaryWriter

from dataset import import_dataset, add_impurity
-from utils.utils import create_batch_tensorboard, logical_and_arrays, focal_loss
+from utils.utils import create_batch_tensorboard, logical_and_arrays, focal_loss, mentorMixLoss
from yml_reader import read_config

def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Module, training_iter:int, epoch:int, data:np.ndarray, data_labels:np.ndarray, optimizer=None, lr_scheduler=None, criterion:torch.nn=None, writer:SummaryWriter=None, is_testing:bool=False, add_gauss_noise:float=0):
@@ -38,6 +38,7 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
mean_loss = 0
mean_accuracy = 0
mean_counter = 0
+loss_p_previous = 0
for i in range(int(size/batch_size)+1):
inputs, labels = data[i*batch_size: np.minimum((i+1)*batch_size, size)], data_labels[i*batch_size: np.minimum((i+1)*batch_size, size)] #We select the current mini-batch
# Every data instance is an input + label pair
@@ -52,11 +53,20 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
else:
# Zero your gradients for every batch!
optimizer.zero_grad()
-# Make predictions for this batch
-outputs = model(inputs)
-# Compute the loss and its gradients
-loss = criterion(outputs, labels)

+if config["training"]["use_mentorMix"]:
+    loss, outputs, loss_p_previous = mentorMixLoss(model, inputs, labels, loss_p_previous, config)
+else:
+    # Make predictions for this batch
+    outputs = model(inputs)
+    # Compute the loss and its gradients
+    loss = criterion(outputs, labels)

+loss.backward()
+# Adjust learning weights
+optimizer.step()
+lr_scheduler.step()
+# Gather data and report

accuracy = torch.mean(torch.where(torch.round(outputs)==labels, 1., 0.))

sig_mask = labels==1
@@ -67,25 +77,23 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
writer.add_scalar(f"Metrics_{key}/TPR", torch.mean(torch.where(torch.round(outputs[sig_mask])==1, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
writer.add_scalar(f"Metrics_{key}/TNR", torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
writer.add_scalar("lr", lr_scheduler.get_last_lr()[0], (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)

wandb.log({f"Loss_{key}": loss,
f"Metrics_{key}/Accuracy": accuracy,
f"Metrics_{key}/TPR": torch.mean(torch.where(torch.round(outputs[sig_mask])==1, 1., 0.)),
f"Metrics_{key}/TNR": torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)),
"lr": lr_scheduler.get_last_lr()[0],
})

+if config["training"]["use_mentorMix"]:
+    wandb.log({f"ema_{key}": loss_p_previous})


mean_loss = (mean_loss * mean_counter + loss )/(mean_counter + 1)
mean_accuracy = (mean_accuracy * mean_counter + accuracy )/(mean_counter + 1)
mean_counter += 1

-if not is_testing:
-    loss.backward()
-    # Adjust learning weights
-    optimizer.step()
-    lr_scheduler.step()
-    # Gather data and report


if not config["training"]["no_print"] and ((i % 30 == 29) or ((i+1) * batch_size >= size)):
loss, current = loss.item(), i * batch_size + len(inputs)
key = "test" if is_testing else "train"
@@ -134,6 +142,7 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
if config["training"]["mode"] == "cross_training":
labels_train = labels_train_dict['clean']
cross_training = config["training"]["extra_args"]["nb_models"]
+performance = np.zeros((cross_training, config["training"]["num_epochs"], 2))
for training_iter in range(cross_training):
if cross_training != 1 and config["training"]["extra_args"]["shared_data"]<1:
shared_data = config["training"]["extra_args"]["shared_data"]
@@ -164,14 +173,35 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
print(f"training_iter: [{training_iter+1}/{cross_training}], epoch: {epoch}, lr: {lr_scheduler.get_last_lr()}")
train_epoch_initializer(epoch=epoch, data=data_train_split, data_labels=labels_train_split, add_gauss_noise=config["training"]["extra_args"]["add_gauss_noise"])
train_epoch_initializer(epoch=epoch, data=data_test, data_labels=labels_test, is_testing=True)
-if epoch % 2 == 0:
-    torch.save(model.state_dict(), tensorboard_log_dir + "/checkpoint" + str(epoch) +"_" + str(training_iter) + ".pth")


+#We test the model and save the performance to plot it.
+with torch.no_grad():
+    model.eval()
+    outputs_train = model(torch.as_tensor(data_train_split, dtype=torch.float32, device=config["device"])).detach().cpu().numpy()
+    performance[training_iter, epoch, 0] = np.mean(np.where(np.round(outputs_train)==labels_train_split, 1., 0.))

+    outputs_test = model(torch.as_tensor(data_test, dtype=torch.float32, device=config["device"])).detach().cpu().numpy()
+    performance[training_iter, epoch, 1] = np.mean(np.where(np.round(outputs_test)==labels_test, 1., 0.))
+    model.train()
+if epoch % 2 == 0:
+    torch.save(model.state_dict(), tensorboard_log_dir + "/checkpoint" + str(epoch) +"_" + str(training_iter) + ".pth")

models.append(model)
if writer is not None:
writer.flush()
writer.close()

+plt.errorbar([i for i in range(1, config["training"]["num_epochs"] + 1)], np.mean(performance[:, :, 0], axis=0), yerr=np.std(performance[:, :, 0], axis=0), fmt="-o", capsize=10, label="train")
+plt.errorbar([i for i in range(1, config["training"]["num_epochs"] + 1)], np.mean(performance[:, :, 1], axis=0), yerr=np.std(performance[:, :, 1], axis=0), fmt="-o", capsize=10, label="validation")
+plt.xlabel("Number of Epochs")
+plt.ylabel("Accuracy (fraction correct)")
+plt.title("Mean performance across training runs")
+plt.legend(loc='upper left')
+plt.show(block=True)

+data_test_tensor = torch.as_tensor(data_test, dtype=torch.float32, device=config["device"])
+labels_test_ = torch.as_tensor(labels_test, dtype=torch.float32, device=config["device"])

elif config["training"]["mode"] == "relabelling":
labels_train = labels_train_dict['noisy']
label_train_clean = labels_train_dict['clean']
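For reference, mentorMixLoss is imported from utils.utils but its implementation is not shown in this diff. The sketch below is one minimal way such a function could look, following the MentorMix recipe of Jiang et al. (2020), "Beyond Synthetic Noise: Deep Learning on Controlled Noisy Labels": weight each example by thresholding its loss at a percentile tracked with an exponential moving average (the value the caller logs as ema_{key}), sample a mixup partner for every example in proportion to those weights, and train on the mixed batch. The hyperparameter keys (gamma_p, ema, alpha) and the sigmoid/binary-cross-entropy setup are assumptions for illustration, not taken from this repository.

import torch
import torch.nn.functional as F

def mentorMixLoss(model, inputs, labels, loss_p_previous, config):
    # Hypothetical hyperparameter keys; the real config layout may differ.
    extra = config["training"]["extra_args"]
    gamma_p = extra.get("gamma_p", 0.7)  # fraction of the batch treated as clean
    ema = extra.get("ema", 0.05)         # EMA rate for the loss percentile
    alpha = extra.get("alpha", 0.4)      # Beta(alpha, alpha) mixup parameter

    # Per-example loss on the raw batch (no reduction), assuming the model
    # outputs sigmoid probabilities as the rest of main.py suggests.
    outputs = model(inputs)
    labels_flat = labels.flatten().float()
    loss_raw = F.binary_cross_entropy(outputs.flatten(), labels_flat, reduction="none")

    # Track the gamma_p-th loss percentile with an exponential moving average.
    loss_p = torch.quantile(loss_raw.detach(), gamma_p).item()
    loss_p_previous = (1 - ema) * loss_p_previous + ema * loss_p

    # MentorNet weights approximated by a threshold: low-loss examples are
    # treated as clean (v=1), high-loss ones as likely mislabeled (v=0).
    v = (loss_raw.detach() <= loss_p_previous).float()

    # Draw a mixing partner for every example, in proportion to the weights
    # (the epsilon keeps multinomial valid if every weight is zero).
    idx = torch.multinomial(v + 1e-8, len(v), replacement=True)
    lam = torch.distributions.Beta(alpha, alpha).sample((len(v),)).to(inputs.device)
    # Clean examples keep the larger mixing share, noisy ones the smaller.
    lam = v * torch.maximum(lam, 1 - lam) + (1 - v) * torch.minimum(lam, 1 - lam)

    lam_x = lam.view(-1, *([1] * (inputs.dim() - 1)))
    mixed_inputs = lam_x * inputs + (1 - lam_x) * inputs[idx]
    mixed_labels = lam * labels_flat + (1 - lam) * labels_flat[idx]

    # The loss on the mixed batch drives the gradient step; the raw outputs
    # are returned so the caller's accuracy/TPR/TNR metrics still compare
    # predictions against the original labels.
    loss = F.binary_cross_entropy(model(mixed_inputs).flatten(), mixed_labels)
    return loss, outputs, loss_p_previous

Because mentorMixLoss runs its own forward passes, the diff above drops the caller's model(inputs)/criterion(outputs, labels) pair on this path and keeps it only in the else branch.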
4 changes: 2 additions & 2 deletions model.py
@@ -44,8 +44,8 @@ def __init__(self, in_channels:int, resblock:F.Module=_ResBlock1d, kernel_size:i
)

self.layer2 = F.Sequential(
-    resblock(16, 32, kernel_size=7, downsample=True),
-    resblock(32, 32, kernel_size=7, downsample=False)
+    resblock(16, 32, kernel_size=kernel_size, downsample=True),
+    resblock(32, 32, kernel_size=kernel_size, downsample=False)
)

self.layer3 = F.Sequential(
