Skip to content

Commit

Permalink
Rewrite core/predict.py to use lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed Nov 30, 2022
1 parent 782b7b3 commit 9f6017d
Showing 1 changed file with 47 additions and 5 deletions.
52 changes: 47 additions & 5 deletions src/vak/core/predict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
import json
import logging
import os
from pathlib import Path

import crowsetta
import joblib
import lightning
import numpy as np
import pandas as pd
from tqdm import tqdm
Expand Down Expand Up @@ -156,7 +158,7 @@ def predict(
item_transform=item_transform,
)

pred_data = torch.utils.data.DataLoader(
pred_loader = torch.utils.data.DataLoader(
dataset=pred_dataset,
shuffle=False,
# batch size 1 because each spectrogram is reshaped into a batch of windows
Expand Down Expand Up @@ -189,12 +191,52 @@ def predict(
for model_name, model in models_map.items():
# ---------------- do the actual predicting --------------------------------------------------------------------
logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}")
model.load(checkpoint_path, device=device)
logger.info(f"running predict method of {model_name}")
pred_dict = model.predict(pred_data=pred_data, device=device)

# TODO: move all this logic down into another function
# TODO: this should be `get_windowed_frame_classification_model_from_config(model_name, model_config)`
model_config = model_config_map[model_name]
num_classes = len(labelmap)
model_config['network'].update({'num_classes': num_classes, 'input_shape': input_shape})
if model_name == 'TweetyNet':
# TODO: make `network` decorator, get these from a namespace that registers them
network = models.TweetyNet(**model_config['network'])
elif model_name == 'TeenyTweetyNet':
network = models.TeenyTweetyNet(**model_config['network'])
else:
raise ValueError(f'unknown model name: {model_name}')
loss_func = torch.nn.CrossEntropyLoss(**model_config['loss'])
optimizer_config = model_config['optimizer']
lbl_tb2labels = functools.partial(labeled_timebins.lbl_tb2labels, labels_mapping=labelmap)
if model_name == 'TweetyNet':
# TODO: make `network` decorator, get these from a namespace that registers them
model = models.TweetyNetModel
elif model_name == 'TeenyTweetyNet':
model = models.TeenyTweetyNetModel
model = model.load_from_checkpoint(checkpoint_path,
network=network,
loss_func=loss_func,
optimizer_config=optimizer_config,
lbl_tb2labels=lbl_tb2labels)

if device == 'cuda':
accelerator = 'gpu'
else:
accelerator = None
trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(
save_dir=output_dir
)
trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)

logger.info(f"running predict method of {model_name}")
results = trainer.predict(model, pred_loader)
# TODO: figure out how to overload `on_predict_epoch_end` to return a dict
pred_dict = {
spect_path: y_pred
for result in results
for spect_path, y_pred in result.items()
}
# ---------------- converting to annotations ------------------------------------------------------------------
progress_bar = tqdm(pred_data)
progress_bar = tqdm(pred_loader)

annots = []
logger.info("converting predictions to annotations")
Expand Down

0 comments on commit 9f6017d

Please sign in to comment.