Skip to content

Commit

Permalink
Use model from models_map in core/predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed Dec 25, 2022
1 parent 88a95e9 commit 97c24c2
Showing 1 changed file with 2 additions and 27 deletions.
29 changes: 2 additions & 27 deletions src/vak/core/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,38 +191,13 @@ def predict(
for model_name, model in models_map.items():
# ---------------- do the actual predicting --------------------------------------------------------------------
logger.info(f"loading checkpoint for {model_name} from path: {checkpoint_path}")

# TODO: move all this logic down into another function
# TODO: this should be 'get_windowed_frame_classification_model_from_config(model_name, model_config)`
model_config = model_config_map[model_name]
num_classes = len(labelmap)
model_config['network'].update({'num_classes': num_classes, 'input_shape': input_shape})
if model_name == 'TweetyNet':
# TODO: make `network` decorator, get these from a namespace it register them in
network = models.TweetyNet(**model_config['network'])
elif model_name == 'TeenyTweetyNet':
network = models.TeenyTweetyNet(**model_config['network'])
else:
raise ValueError(f'unknown model name: {model_name}')
loss_func = torch.nn.CrossEntropyLoss(**model_config['loss'])
optimizer_config = model_config['optimizer']
lbl_tb2labels = functools.partial(labeled_timebins.lbl_tb2labels, labels_mapping=labelmap)
if model_name == 'TweetyNet':
# TODO: make `network` decorator, get these from a namespace it register them in
model = models.TweetyNetModel
elif model_name == 'TeenyTweetyNet':
model = models.TeenyTweetyNetModel
model = model.load_from_checkpoint(checkpoint_path,
network=network,
loss_func=loss_func,
optimizer_config=optimizer_config,
lbl_tb2labels=lbl_tb2labels)
model.load_state_dict_from_path(checkpoint_path)

if device == 'cuda':
accelerator = 'gpu'
else:
accelerator = None
trainer_logger = lightning.pytorch.loggers.TensorBoardLogger(
trainer_logger = lightning.loggers.TensorBoardLogger(
save_dir=output_dir
)
trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger)
Expand Down

0 comments on commit 97c24c2

Please sign in to comment.