diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py index 9831efbd15..432f81ea3d 100644 --- a/fairseq/models/model_utils.py +++ b/fairseq/models/model_utils.py @@ -108,7 +108,7 @@ def fill_tensors(x, mask, y, padding_idx: int): x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx) x[mask] = y elif x.size(1) > y.size(1): - x[mask] = torch.tensor(padding_idx) + x[mask] = torch.tensor(padding_idx).type_as(x) if x.dim() == 2: x[mask, :y.size(1)] = y else: