Misc fixes (#2524)
Summary: Pull Request resolved: #2524

Reviewed By: ngoyal2707

Differential Revision: D23318746

Pulled By: myleott

fbshipit-source-id: 6db6a87aac178847bd0da26db09b1a63632a724f
myleott authored and facebook-github-bot committed Aug 31, 2020
1 parent 0989eca commit fe1b1bb
Showing 14 changed files with 58 additions and 37 deletions.
24 changes: 13 additions & 11 deletions fairseq/benchmark/dummy_mt.py
@@ -23,9 +23,8 @@ def add_args(parser):
         """Add task-specific arguments to the parser."""
         parser.add_argument('--dict-size', default=49996, type=int)
         parser.add_argument('--dataset-size', default=100000, type=int)
-        parser.add_argument('--tokens-per-sample', default=512, type=int,
-                            help='max number of total tokens over all segments '
-                                 'per sample for BERT dataset')
+        parser.add_argument('--src-len', default=30, type=int)
+        parser.add_argument('--tgt-len', default=30, type=int)

     def __init__(self, args, dictionary):
         super().__init__(args)
@@ -34,10 +33,8 @@ def __init__(self, args, dictionary):

         dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

-        seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
-
-        self.dummy_src = seq[:-1]
-        self.dummy_tgt = seq[1:]
+        self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
+        self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1

     @classmethod
     def setup_task(cls, args, **kwargs):
@@ -46,34 +43,39 @@ def setup_task(cls, args, **kwargs):
         for i in range(args.dict_size):
             dictionary.add_symbol('word{}'.format(i))
         logger.info('dictionary: {} types'.format(len(dictionary)))
+
+        args.max_source_positions = args.src_len + dictionary.pad() + 2
+        args.max_target_positions = args.tgt_len + dictionary.pad() + 2
+
         return cls(args, dictionary)

     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
+        item_size = max(self.args.src_len, self.args.tgt_len)
         if self.args.max_sentences is not None:
             bsz = self.args.max_sentences
         else:
-            bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
+            bsz = max(1, self.args.max_tokens // item_size)
         tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
         self.datasets[split] = DummyDataset(
             {
                 'id': 1,
                 'net_input': {
                     'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
                     'src_lengths': torch.full(
-                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+                        (bsz, ), self.args.src_len, dtype=torch.long
                     ),
                     'prev_output_tokens': tgt.clone(),
                 },
                 'target': tgt,
                 'nsentences': bsz,
-                'ntokens': bsz * self.args.tokens_per_sample,
+                'ntokens': bsz * self.args.tgt_len,
             },
             num_items=self.args.dataset_size,
-            item_size=self.args.tokens_per_sample,
+            item_size=item_size,
         )

     @property
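Note on the dummy_mt change: the benchmark task now sizes its fixed source/target sequences from --src-len and --tgt-len instead of --tokens-per-sample. A minimal standalone sketch of how the dummy tensors are built (src_len here is arbitrary; fairseq's Dictionary normally uses pad index 1):

    import torch

    src_len, pad = 5, 1
    # Token ids start just past the padding index, so no special symbols
    # appear in the dummy batch.
    dummy_src = torch.arange(src_len + 1) + pad + 1
    print(dummy_src)  # tensor([2, 3, 4, 5, 6, 7])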
4 changes: 2 additions & 2 deletions fairseq/data/denoising_dataset.py
@@ -210,7 +210,7 @@ def permute_sentences(self, source, p=1.0):
         full_stops[-2] = 1

         # Tokens that are full stops, where the previous token is not
-        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2
+        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
         result = source.clone()

         num_sentences = sentence_ends.size(0)
@@ -271,7 +271,7 @@ def add_whole_word_mask(self, source, p):
         else:
             lengths = torch.ones((num_to_mask,)).long()
         assert is_word_start[-1] == 0
-        word_starts = is_word_start.nonzero()
+        word_starts = is_word_start.nonzero(as_tuple=False)
         indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1)
         mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
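The only change here is the explicit as_tuple=False. A small sketch of the behaviour this relies on: recent PyTorch releases warn when .nonzero() is called without the argument, while the returned indices are unchanged.

    import torch

    mask = torch.tensor([0, 1, 0, 1], dtype=torch.bool)
    idx = mask.nonzero(as_tuple=False)   # (N, 1) index tensor: [[1], [3]], no warning
    rows = mask.nonzero(as_tuple=True)   # tuple of 1-D index tensors instead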
4 changes: 1 addition & 3 deletions fairseq/distributed_utils.py
@@ -115,9 +115,7 @@ def distributed_init(args):
         xm.rendezvous('distributed_init')  # wait for all workers
         xm.mark_step()

-    if is_master(args):
-        logging.getLogger().setLevel(logging.INFO)
-    else:
+    if not is_master(args):
         logging.getLogger().setLevel(logging.WARNING)

     if args.model_parallel_size > 1:
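This appears to pair with the LOGLEVEL changes further down: the root logger's level is now left untouched on the master rank (so an environment-selected level is respected), and only non-master ranks are raised to WARNING.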
7 changes: 6 additions & 1 deletion fairseq/model_parallel/models/roberta/model.py
@@ -69,6 +69,11 @@ def build_model(cls, args, task):
         if not hasattr(args, 'max_positions'):
             args.max_positions = args.tokens_per_sample

+        if getattr(args, 'untie_weights_roberta', False):
+            raise NotImplementedError(
+                '--untie-weights-roberta is not supported in model parallel mode'
+            )
+
         encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
         return cls(args, encoder)

@@ -127,7 +132,7 @@ def forward(self, features, masked_tokens=None, **kwargs):
         x = self.activation_fn(x)
         x = self.layer_norm(x)

-        features = copy_to_model_parallel_region(features)
+        x = copy_to_model_parallel_region(x)
         # project back to size of vocabulary with bias
         x = F.linear(x, self.weight)
         x = gather_from_model_parallel_region(x).contiguous()
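In the second hunk, the previous code copied `features` into the model-parallel region while the following F.linear consumed `x`, so the copied result was effectively discarded; the call now operates on `x`, whose result feeds the projection.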
2 changes: 1 addition & 1 deletion fairseq/models/lstm.py
@@ -268,7 +268,7 @@ def forward(

         # pack embedded source tokens into a PackedSequence
         packed_x = nn.utils.rnn.pack_padded_sequence(
-            x, src_lengths.data, enforce_sorted=enforce_sorted
+            x, src_lengths.cpu(), enforce_sorted=enforce_sorted
         )

         # apply LSTM
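The lstm.py change accommodates newer PyTorch, which expects the lengths argument of pack_padded_sequence to live on the CPU. A minimal sketch (shapes are illustrative):

    import torch
    import torch.nn as nn

    x = torch.randn(5, 3, 8)           # (seq_len, batch, embed_dim)
    lengths = torch.tensor([5, 3, 2])  # may arrive on GPU in real training
    # .cpu() is a no-op if the tensor is already on the CPU.
    packed = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), enforce_sorted=True)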
6 changes: 5 additions & 1 deletion fairseq/models/roberta/model.py
@@ -314,7 +314,11 @@ def __init__(self, args, dictionary):
             embed_dim=args.encoder_embed_dim,
             output_dim=len(dictionary),
             activation_fn=args.activation_fn,
-            weight=self.sentence_encoder.embed_tokens.weight if not args.untie_weights_roberta else None,
+            weight=(
+                self.sentence_encoder.embed_tokens.weight
+                if not args.untie_weights_roberta
+                else None
+            ),
         )

     def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
6 changes: 3 additions & 3 deletions fairseq/optim/adafactor.py
@@ -138,9 +138,9 @@ def _rms(self, tensor):
     def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
         r_factor = (
             exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
-        ).rsqrt_()
-        c_factor = exp_avg_sq_col.rsqrt()
-        return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
+        ).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)

     def step(self, closure=None):
         """Performs a single optimization step.
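For the 1-D row/column statistics kept for a 2-D parameter, the broadcasted torch.mul yields the same outer product as the old torch.mm, and unlike mm it is not restricted to 2-D inputs. A quick equivalence check under that assumption:

    import torch

    row = torch.rand(4) + 0.1   # stands in for exp_avg_sq_row
    col = torch.rand(3) + 0.1   # stands in for exp_avg_sq_col

    r = (row / row.mean(dim=-1, keepdim=True)).rsqrt()
    old = torch.mm(r.unsqueeze(-1), col.rsqrt().unsqueeze(0))    # outer product via mm
    new = torch.mul(r.unsqueeze(-1), col.unsqueeze(-2).rsqrt())  # same values via broadcasting
    assert torch.allclose(old, new)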
3 changes: 2 additions & 1 deletion fairseq/optim/fp16_optimizer.py
@@ -17,6 +17,7 @@ class _FP16OptimizerMixin(object):
     def __init__(self, *args, **kwargs):
         # forward __init__ call to the next class in mro(method resolution order)
         super().__init__(*args, **kwargs)
+        self._multiply_factor = 1.

     @property
     def has_flat_params(self):
@@ -135,7 +136,7 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
                 self._multiply_factor *= max_norm / grad_norm

             self.scaler.check_overflow(grad_norm)
-        else:
+        elif max_norm > 0.0:
             clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
             self._multiply_factor *= clip_coef

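The elif guard matters when clipping is disabled (max_norm == 0): the old else branch computed a clip coefficient of zero, which would scale the gradients away. Illustrative arithmetic only:

    import torch

    grad_norm = torch.tensor(5.0)
    clip_coef = (0.0 / (grad_norm + 1e-6)).clamp(max=1)
    print(clip_coef)  # tensor(0.) -- multiplying gradients by this would erase them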
2 changes: 1 addition & 1 deletion fairseq_cli/eval_lm.py
@@ -25,7 +25,7 @@
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S',
-    level=logging.INFO,
+    level=os.environ.get('LOGLEVEL', 'INFO').upper(),
 )
 logger = logging.getLogger('fairseq_cli.eval_lm')

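The same pattern is applied to generate.py, interactive.py, preprocess.py, train.py and validate.py below: the log level now comes from an optional LOGLEVEL environment variable and defaults to INFO. A standalone sketch of the resolution (Python's logging accepts level names as strings):

    import logging
    import os

    # e.g. LOGLEVEL=WARNING in the environment suppresses INFO-level messages.
    level = os.environ.get('LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=level)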
2 changes: 1 addition & 1 deletion fairseq_cli/generate.py
@@ -49,7 +49,7 @@ def _main(args, output_file):
     logging.basicConfig(
         format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
         datefmt='%Y-%m-%d %H:%M:%S',
-        level=logging.INFO,
+        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
         stream=output_file,
     )
     logger = logging.getLogger('fairseq_cli.generate')
2 changes: 1 addition & 1 deletion fairseq_cli/interactive.py
@@ -27,7 +27,7 @@
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S',
-    level=logging.INFO,
+    level=os.environ.get('LOGLEVEL', 'INFO').upper(),
     stream=sys.stdout,
 )
 logger = logging.getLogger('fairseq_cli.interactive')
2 changes: 1 addition & 1 deletion fairseq_cli/preprocess.py
@@ -23,7 +23,7 @@
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S',
-    level=logging.INFO,
+    level=os.environ.get('LOGLEVEL', 'INFO').upper(),
     stream=sys.stdout,
 )
 logger = logging.getLogger('fairseq_cli.preprocess')
28 changes: 19 additions & 9 deletions fairseq_cli/train.py
@@ -10,6 +10,7 @@
 import argparse
 import logging
 import math
+import os
 import random
 import sys

@@ -32,7 +33,7 @@
 logging.basicConfig(
     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
-    level=logging.INFO,
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
     stream=sys.stdout,
 )
 logger = logging.getLogger("fairseq_cli.train")
@@ -229,16 +230,26 @@ def train(args, trainer, task, epoch_itr):

 def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
     num_updates = trainer.get_num_updates()
+    max_update = args.max_update or math.inf
     do_save = (
-        args.save_interval_updates > 0
-        and num_updates > 0
-        and num_updates % args.save_interval_updates == 0
-        and num_updates >= args.validate_after_updates
-    ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
+        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
+        or num_updates >= max_update
+        or (
+            args.save_interval_updates > 0
+            and num_updates > 0
+            and num_updates % args.save_interval_updates == 0
+            and num_updates >= args.validate_after_updates
+        )
+    )
     do_validate = (
         (not end_of_epoch and do_save)  # validate during mid-epoch saves
         or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
-        or (args.validate_interval_updates > 0 and num_updates > 0 and num_updates % args.validate_interval_updates == 0)
+        or num_updates >= max_update
+        or (
+            args.validate_interval_updates > 0
+            and num_updates > 0
+            and num_updates % args.validate_interval_updates == 0
+        )
     ) and not args.disable_validation

     # Validate
@@ -247,10 +258,9 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
         valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

     # Stopping conditions
-    max_update = args.max_update or math.inf
     should_stop = (
         should_stop_early(args, valid_losses[0])
-        or trainer.get_num_updates() >= max_update
+        or num_updates >= max_update
         or (
             args.stop_time_hours > 0
             and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
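Beyond the LOGLEVEL change, validate_and_save now also triggers a save/validate once num_updates reaches max_update, and the single max_update value is reused for the stopping check. A standalone sketch of the new do_save predicate (hypothetical helper; argument defaults are illustrative, not fairseq's):

    import math

    def should_save(num_updates, end_of_epoch, epoch, save_interval=1,
                    save_interval_updates=0, validate_after_updates=0, max_update=0):
        max_update = max_update or math.inf
        return (
            (end_of_epoch and epoch % save_interval == 0)
            or num_updates >= max_update
            or (
                save_interval_updates > 0
                and num_updates > 0
                and num_updates % save_interval_updates == 0
                and num_updates >= validate_after_updates
            )
        )

    # A checkpoint is now forced exactly when the update budget is exhausted.
    assert should_save(num_updates=100, end_of_epoch=False, epoch=1, max_update=100)
    assert not should_save(num_updates=99, end_of_epoch=False, epoch=1, max_update=100)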
3 changes: 2 additions & 1 deletion fairseq_cli/validate.py
@@ -7,6 +7,7 @@

 from itertools import chain
 import logging
+import os
 import sys

 import torch
@@ -18,7 +19,7 @@
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S',
-    level=logging.INFO,
+    level=os.environ.get('LOGLEVEL', 'INFO').upper(),
     stream=sys.stdout,
 )
 logger = logging.getLogger('fairseq_cli.validate')
