Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Squad inference example #390

Merged
merged 4 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ files, along with unit tests, examples and tutorials

- Added TRADE (dialogue state tracking model) on MultiWOZ dataset
([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX
- Question answering:
([PR #390](https://github.com/NVIDIA/NeMo/pull/390)) - @yzhang123
- Changed question answering task to use Roberta and Albert as alternative backends to Bert
- Added inference mode that does not require ground truth labels

### Dependencies Update
- Added dependency on `wrapt` (the new version of the `deprecated` warning) - @tkornuta-nvidia, @DEKHTIARJonathan
Expand Down
196 changes: 137 additions & 59 deletions examples/nlp/question_answering/question_answering_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
--weight_decay 0.0
--lr 3e-5
--do_lower_case
--mode train_eval

If --bert_checkpoint is not specified, training starts from
Huggingface pretrained checkpoints.
Expand All @@ -55,17 +56,34 @@
--optimizer adam_w
--weight_decay 0.0
--lr 3e-5
--mode train_eval
--do_lower_case

On Huggingface the final Exact Match (EM) and F1 scores are as follows:
Model EM F1
BERT Based uncased 80.59 88.34
BERT Large uncased 83.88 90.65

To run only evaluation on pretrained question answering checkpoints on 1 GPU with ground-truth data:
python question_answering_squad.py
--dev_file /path_to_data_dir/infer.json
--checkpoint_dir /path_to_checkpoints
--do_lower_case
--mode eval

To run only inference on pretrained question answering checkpoints on 1 GPU without ground-truth data:
python question_answering_squad.py
--infer_file /path_to_data_dir/infer.json
--checkpoint_dir /path_to_checkpoints
--do_lower_case
--mode infer
"""
import argparse
import json
import os

import numpy as np

import nemo.collections.nlp as nemo_nlp
import nemo.core as nemo_core
from nemo import logging
Expand All @@ -79,7 +97,12 @@ def parse_args():
"--train_file", type=str, help="The training data file. Should be *.json",
)
parser.add_argument(
"--dev_file", type=str, required=True, help="The evaluation data file. Should be *.json",
"--dev_file", type=str, help="The evaluation data file. Should be *.json",
)
parser.add_argument(
"--infer_file",
type=str,
help="The inference data file. Should be *.json. Does not need to contain ground truth",
)
parser.add_argument("--pretrained_model_name", type=str, help="Name of the pre-trained model")
parser.add_argument("--checkpoint_dir", default=None, type=str, help="Checkpoint directory for inference.")
Expand Down Expand Up @@ -115,7 +138,9 @@ def parse_args():
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.",
)
parser.add_argument("--evaluation_only", action='store_true', help="Whether to only do evaluation.")
parser.add_argument(
"--mode", default="train_eval", choices=["train_eval", "eval", "infer"], help="Mode of model usage."
)
parser.add_argument(
"--no_data_cache", action='store_true', help="When specified do not load and store cache preprocessed data.",
)
Expand Down Expand Up @@ -209,15 +234,15 @@ def create_pipeline(
data_file,
model,
head,
loss_fn,
max_query_length,
max_seq_length,
doc_stride,
batch_size,
version_2_with_negative,
mode,
num_gpus=1,
batches_per_step=1,
mode="train",
loss_fn=None,
use_data_cache=True,
):
data_layer = nemo_nlp.nm.data_layers.BertQuestionAnsweringDataLayer(
Expand All @@ -239,17 +264,26 @@ def create_pipeline(
)

qa_output = head(hidden_states=hidden_states)
loss_output = loss_fn(
logits=qa_output, start_positions=input_data.start_positions, end_positions=input_data.end_positions
)

steps_per_epoch = len(data_layer) // (batch_size * num_gpus * batches_per_step)
return (
loss_output.loss,
steps_per_epoch,
[loss_output.start_logits, loss_output.end_logits, input_data.unique_ids],
data_layer,
)

if mode == "infer":
return (
steps_per_epoch,
[input_data.unique_ids, qa_output],
data_layer,
)
else:
loss_output = loss_fn(
logits=qa_output, start_positions=input_data.start_positions, end_positions=input_data.end_positions
)

return (
loss_output.loss,
steps_per_epoch,
[input_data.unique_ids, loss_output.start_logits, loss_output.end_logits],
data_layer,
)


MODEL_CLASSES = {
Expand All @@ -261,14 +295,24 @@ def create_pipeline(

if __name__ == "__main__":
args = parse_args()
if not os.path.exists(args.dev_file):
raise FileNotFoundError(
"eval data not found. Datasets can be obtained using examples/nlp/scripts/get_squad.py"
)
if not args.evaluation_only and not os.path.exists(args.train_file):
raise FileNotFoundError(
"train data not found. Datasets can be obtained using examples/nlp/scripts/get_squad.py"
)

if args.mode == "train_eval":
if not os.path.exists(args.train_file) or not os.path.exists(args.dev_file):
raise FileNotFoundError(
"train and dev data not found. Datasets can be obtained using examples/nlp/scripts/get_squad.py"
)
elif args.mode == "eval":
if not os.path.exists(args.dev_file):
raise FileNotFoundError(
"dev data not found. Datasets can be obtained using examples/nlp/scripts/get_squad.py"
)
elif args.mode == "infer":
if not os.path.exists(args.infer_file):
raise FileNotFoundError(
"infer data not found. Datasets can be obtained using examples/nlp/scripts/get_squad.py"
)
else:
raise ValueError(f"{args.mode} can only be one of [train_eval, eval, infer]")

# Instantiate neural factory with supported backend
nf = nemo_core.NeuralModuleFactory(
Expand Down Expand Up @@ -328,7 +372,7 @@ def create_pipeline(
if args.bert_checkpoint is not None:
model.restore_from(args.bert_checkpoint)

if not args.evaluation_only:
if "train" in args.mode:
train_loss, train_steps_per_epoch, _, _ = create_pipeline(
data_file=args.train_file,
model=model,
Expand All @@ -344,24 +388,39 @@ def create_pipeline(
mode="train",
use_data_cache=not args.no_data_cache,
)
logging.info(f"training step per epoch: {train_steps_per_epoch}")
_, _, eval_output, eval_data_layer = create_pipeline(
data_file=args.dev_file,
model=model,
head=qa_head,
loss_fn=squad_loss,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
batch_size=args.batch_size,
version_2_with_negative=args.version_2_with_negative,
num_gpus=args.num_gpus,
batches_per_step=args.batches_per_step,
mode="dev",
use_data_cache=not args.no_data_cache,
)
if "eval" in args.mode:
_, _, eval_output, eval_data_layer = create_pipeline(
data_file=args.dev_file,
model=model,
head=qa_head,
loss_fn=squad_loss,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
batch_size=args.batch_size,
version_2_with_negative=args.version_2_with_negative,
num_gpus=args.num_gpus,
batches_per_step=args.batches_per_step,
mode="dev",
use_data_cache=not args.no_data_cache,
)
if "infer" in args.mode:
_, eval_output, infer_data_layer = create_pipeline(
data_file=args.infer_file,
model=model,
head=qa_head,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
batch_size=args.batch_size,
version_2_with_negative=args.version_2_with_negative,
num_gpus=args.num_gpus,
batches_per_step=args.batches_per_step,
mode="infer",
use_data_cache=not args.no_data_cache,
)

if not args.evaluation_only:
if args.mode == "train_eval":
logging.info(f"steps_per_epoch = {train_steps_per_epoch}")
callback_train = nemo_core.SimpleLossLoggerCallback(
tensors=[train_loss],
Expand Down Expand Up @@ -402,33 +461,52 @@ def create_pipeline(
batches_per_step=args.batches_per_step,
optimization_params={"num_epochs": args.num_epochs, "lr": args.lr},
)
else:

else:
load_from_folder = None
if args.checkpoint_dir is not None:
load_from_folder = args.checkpoint_dir

evaluated_tensors = nf.infer(tensors=eval_output, checkpoint_dir=load_from_folder, cache=True)
unique_ids = []
start_logits = []
end_logits = []
for t in evaluated_tensors[2]:
unique_ids.extend(t.tolist())
for t in evaluated_tensors[0]:
start_logits.extend(t.tolist())
for t in evaluated_tensors[1]:
end_logits.extend(t.tolist())

exact_match, f1, all_predictions = eval_data_layer.dataset.evaluate(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits,
n_best_size=args.n_best_size,
max_answer_length=args.max_answer_length,
version_2_with_negative=args.version_2_with_negative,
null_score_diff_threshold=args.null_score_diff_threshold,
do_lower_case=args.do_lower_case,
)
unique_ids.extend(t.tolist())
if "eval" in args.mode:
start_logits = []
end_logits = []
for t in evaluated_tensors[1]:
start_logits.extend(t.tolist())
for t in evaluated_tensors[2]:
end_logits.extend(t.tolist())

exact_match, f1, all_predictions = eval_data_layer.dataset.evaluate(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits,
n_best_size=args.n_best_size,
max_answer_length=args.max_answer_length,
version_2_with_negative=args.version_2_with_negative,
null_score_diff_threshold=args.null_score_diff_threshold,
do_lower_case=args.do_lower_case,
)

logging.info(f"exact_match: {exact_match}, f1: {f1}")
logging.info(f"exact_match: {exact_match}, f1: {f1}")

elif "infer" in args.mode:
logits = []
for t in evaluated_tensors[1]:
logits.extend(t.tolist())
start_logits, end_logits = np.split(np.asarray(logits), 2, axis=-1)
(all_predictions, all_nbest_json, scores_diff_json) = infer_data_layer.dataset.get_predictions(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits,
n_best_size=args.n_best_size,
max_answer_length=args.max_answer_length,
version_2_with_negative=args.version_2_with_negative,
null_score_diff_threshold=args.null_score_diff_threshold,
do_lower_case=args.do_lower_case,
)
if args.output_prediction_file is not None:
with open(args.output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
30 changes: 19 additions & 11 deletions nemo/collections/nlp/data/datasets/qa_squad_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def __init__(
self.version_2_with_negative = version_2_with_negative
self.processor = SquadProcessor(data_file=data_file, mode=mode)
self.mode = mode
if mode != "dev" and mode != "train":
raise ValueError(f"mode should be either 'train' or 'dev' but got {mode}")
if mode not in ["dev", "train", "infer"]:
raise ValueError(f"mode should be either 'train', 'dev', or 'infer' but got {mode}")
self.examples = self.processor.get_examples()

cached_features_file = (
Expand All @@ -107,7 +107,7 @@ def __init__(
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
has_groundtruth=True,
has_groundtruth=mode != "infer",
)

if use_cache:
Expand All @@ -122,14 +122,22 @@ def __len__(self):

def __getitem__(self, idx):
feature = self.features[idx]
return (
np.array(feature.input_ids),
np.array(feature.segment_ids),
np.array(feature.input_mask),
np.array(feature.start_position),
np.array(feature.end_position),
np.array(feature.unique_id),
)
if self.mode == "infer":
return (
np.array(feature.input_ids),
np.array(feature.segment_ids),
np.array(feature.input_mask),
np.array(feature.unique_id),
)
else:
return (
np.array(feature.input_ids),
np.array(feature.segment_ids),
np.array(feature.input_mask),
np.array(feature.unique_id),
np.array(feature.start_position),
np.array(feature.end_position),
)

def get_predictions(
self,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/nlp/nm/data_layers/qa_squad_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def output_ports(self):
"input_ids": NeuralType(('B', 'T'), ChannelType()),
"input_type_ids": NeuralType(('B', 'T'), ChannelType()),
"input_mask": NeuralType(('B', 'T'), ChannelType()),
"start_positions": NeuralType(tuple('B'), ChannelType()),
"end_positions": NeuralType(tuple('B'), ChannelType()),
"unique_ids": NeuralType(tuple('B'), ChannelType()),
"start_positions": NeuralType(tuple('B'), ChannelType(), optional=True),
"end_positions": NeuralType(tuple('B'), ChannelType(), optional=True),
Comment on lines +63 to +64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@okuchaiev What exactly does optional mean? How does it work if I only return end_positions but not start_positions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_positions, end_positions are the "labels", they are optional to make the data_layer more general for inference cases where i do not have the groundtruth.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, I just didn't know we had this option, and want to know how it works.

}

def __init__(
Expand Down