Commit

prep_mentormix
Liphos committed Sep 23, 2022
1 parent 91111fb commit 9012486
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions main.py
@@ -13,6 +13,7 @@
 plt.clf()
 
 import torch
+import wandb
 from torch.utils.tensorboard import SummaryWriter
 
 from dataset import import_dataset, add_impurity
@@ -61,12 +62,18 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
         sig_mask = labels==1
         if writer is not None:
             key = "test" if is_testing else "train"
-            writer.add_scalar("Loss_"+ key, loss, (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
-            writer.add_scalar("Metrics_"+ key +"/Accuracy", accuracy, (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
-            writer.add_scalar("Metrics_"+ key +"/TPR", torch.mean(torch.where(torch.round(outputs[sig_mask])==1, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
-            writer.add_scalar("Metrics_"+ key +"/TNR", torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
-            writer.add_scalar("Metrics_"+ key +"/TNR", torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
+            writer.add_scalar(f"Loss_{key}", loss, (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
+            writer.add_scalar(f"Metrics_{key}/Accuracy", accuracy, (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
+            writer.add_scalar(f"Metrics_{key}/TPR", torch.mean(torch.where(torch.round(outputs[sig_mask])==1, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
+            writer.add_scalar(f"Metrics_{key}/TNR", torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)), (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
             writer.add_scalar("lr", lr_scheduler.get_last_lr()[0], (int(size/batch_size) + 1 ) * epoch + i, new_style=True if i==0 else False)
+
+            wandb.log({f"Loss_{key}": loss,
+                       f"Metrics_{key}/Accuracy": accuracy,
+                       f"Metrics_{key}/TPR": torch.mean(torch.where(torch.round(outputs[sig_mask])==1, 1., 0.)),
+                       f"Metrics_{key}/TNR": torch.mean(torch.where(torch.round(outputs[~sig_mask])==0, 1., 0.)),
+                       "lr": lr_scheduler.get_last_lr()[0],
+                       })
 
         mean_loss = (mean_loss * mean_counter + loss )/(mean_counter + 1)
         mean_accuracy = (mean_accuracy * mean_counter + accuracy )/(mean_counter + 1)
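A note on this logging hunk: the step argument `(int(size/batch_size) + 1) * epoch + i` is a global batch counter that keeps increasing across epochs, and `new_style=True if i==0 else False` simplifies to `new_style=(i == 0)`. Below is a minimal sketch of the same TensorBoard-plus-wandb mirroring; the helper name `log_step` and its signature are illustrative assumptions, not code from this commit:

```python
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter

def log_step(writer: SummaryWriter, key: str, loss: torch.Tensor,
             accuracy: float, outputs: torch.Tensor, sig_mask: torch.Tensor,
             lr: float, step: int, first_batch: bool) -> None:
    """Mirror one batch's scalars to TensorBoard and wandb (sketch only)."""
    # TPR: fraction of signal events (label 1) whose rounded prediction is 1.
    tpr = torch.mean(torch.where(torch.round(outputs[sig_mask]) == 1, 1., 0.))
    # TNR: fraction of background events (label 0) whose rounded prediction is 0.
    tnr = torch.mean(torch.where(torch.round(outputs[~sig_mask]) == 0, 1., 0.))
    metrics = {f"Loss_{key}": loss, f"Metrics_{key}/Accuracy": accuracy,
               f"Metrics_{key}/TPR": tpr, f"Metrics_{key}/TNR": tnr, "lr": lr}
    for tag, value in metrics.items():
        # first_batch plays the role of the diff's "True if i==0 else False".
        writer.add_scalar(tag, value, step, new_style=first_batch)
    # Log the same dictionary to wandb at the same global step.
    wandb.log(metrics, step=step)
```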
@@ -96,6 +103,9 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
 #Gather configs from config file
 config = read_config(args["config"])
 
+#initiate wandb
+wandb.init(project="trend", entity="liphos", config=config, name=config["comment"],)
+wandb.config = config
 #Set seed
 np.random.seed(config["seed"])
 print("Using " + config["device"] + " device")
@@ -214,15 +224,16 @@ def train_epoch(config:Dict[str, Union[str,int, List[int]]], model:torch.nn.Modu
             new_labels[labels_mean>=0.8] = 1
             new_labels[labels_mean<=1-0.8] = 0
         elif config["training"]["extra_args"]["mode"] == "unanimity":
-            new_labels[logical_and_arrays([torch.where(model(data_tensor)[:, 0]>=0.7, True, False) for model in models])] = 1
-            new_labels[logical_and_arrays([torch.where(model(data_tensor)[:, 0]<=1-0.7, True, False) for model in models])] = 0
+            new_labels[logical_and_arrays([np.where(model(data_tensor)[:, 0].detach().cpu().numpy()>=0.7, True, False) for model in models])] = 1
+            new_labels[logical_and_arrays([np.where(model(data_tensor)[:, 0].detach().cpu().numpy()<=1-0.7, True, False) for model in models])] = 0
         else:
             raise ValueError("This mode for relabelling doesn't exist.")
         new_nb_correct = np.where(label_train_clean==new_labels)[0].shape[0]
 
         print(f"Previous labels:{np.where(new_labels[:,0]==labels_train[:,0])[0].shape}/{len(labels_train)}")
         print(f"Correct labels: {new_nb_correct}/{len(new_labels)}")
-        print(f"Percentage of correct labels in labels that were changed: {(1+(new_nb_correct-old_nb_correct)/(len(labels_train) - np.where(new_labels[:,0]==labels_train[:,0])[0].shape[0]))/2}")
+        if (len(labels_train) - np.where(new_labels[:,0]==labels_train[:,0])[0].shape[0]) >0:
+            print(f"Percentage of correct labels in labels that were changed: {(1+(new_nb_correct-old_nb_correct)/(len(labels_train) - np.where(new_labels[:,0]==labels_train[:,0])[0].shape[0]))/2}")
 
         old_nb_correct = new_nb_correct
         labels_train = new_labels
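Two points in this last hunk deserve a word. First, `torch.where` had to become `np.where` over `.detach().cpu().numpy()` because the resulting boolean masks index `new_labels`, a NumPy array, and NumPy arrays cannot be indexed by (possibly GPU-resident) torch tensors. Second, the guarded print reports the fraction of flipped labels that were flipped correctly: with n labels changed, f of them fixed and b broken, f + b = n and f - b = new_nb_correct - old_nb_correct, so f/n = (1 + delta/n)/2; the new `if` simply avoids dividing by zero when no label changed. Below is a sketch of the unanimity rule, assuming `logical_and_arrays` reduces a list of boolean arrays with elementwise AND (its body is not shown in this diff); `unanimity_masks` is an illustrative name, not a function from main.py:

```python
from functools import reduce

import numpy as np

def logical_and_arrays(masks):
    # Assumed behaviour of the helper used in main.py:
    # elementwise AND over a list of boolean arrays.
    return reduce(np.logical_and, masks)

def unanimity_masks(models, data_tensor, threshold=0.7):
    """Relabel candidates where every ensemble member agrees (sketch)."""
    # detach/cpu/numpy so the masks can index NumPy arrays such as new_labels.
    probs = [m(data_tensor)[:, 0].detach().cpu().numpy() for m in models]
    # p >= threshold is equivalent to the diff's np.where(p >= 0.7, True, False).
    all_signal = logical_and_arrays([p >= threshold for p in probs])
    all_background = logical_and_arrays([p <= 1 - threshold for p in probs])
    return all_signal, all_background
```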
