Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolved #38 #56

Merged
merged 1 commit into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions openspeech/decoders/rnn_transducer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ def forward(
* hidden_states (torch.FloatTensor): A hidden state of decoders. `FloatTensor` of size
``(batch, seq_length, dimension)``
"""
batch_size, input_lengths = inputs.size(0), inputs.size(1)

if input_lengths != 1:
inputs = inputs[inputs != self.eos_id].view(batch_size, -1)

embedded = self.embedding(inputs)

if hidden_states is not None:
Expand Down
8 changes: 4 additions & 4 deletions openspeech/decoders/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def __init__(
) -> None:
super(TransformerDecoderLayer, self).__init__()
self.self_attention_prenorm = nn.LayerNorm(d_model)
self.encoder_attention_prenorm = nn.LayerNorm(d_model)
self.decoder_attention_prenorm = nn.LayerNorm(d_model)
self.feed_forward_prenorm = nn.LayerNorm(d_model)
self.self_attention = MultiHeadAttention(d_model, num_heads)
self.encoder_attention = MultiHeadAttention(d_model, num_heads)
self.decoder_attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout_p)

def forward(
Expand Down Expand Up @@ -108,8 +108,8 @@ def forward(
outputs += residual

residual = outputs
outputs = self.encoder_attention_prenorm(outputs)
outputs, encoder_attn = self.encoder_attention(outputs, encoder_outputs, encoder_outputs, encoder_attn_mask)
outputs = self.decoder_attention_prenorm(outputs)
outputs, encoder_attn = self.decoder_attention(outputs, encoder_outputs, encoder_outputs, encoder_attn_mask)
outputs += residual

residual = outputs
Expand Down
3 changes: 1 addition & 2 deletions openspeech/decoders/transformer_transducer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.pad_id = pad_id
self.sos_id = sos_id
self.eos_id = eos_id
self.encoder_layers = nn.ModuleList([
self.decoder_layers = nn.ModuleList([
TransformerTransducerEncoderLayer(
model_dim,
d_ff,
Expand Down Expand Up @@ -124,7 +124,6 @@ def forward(
)

else: # train
inputs = inputs[inputs != self.eos_id].view(batch, -1)
target_lengths = inputs.size(1)

outputs = self.forward_step(
Expand Down
78 changes: 25 additions & 53 deletions openspeech/models/openspeech_transducer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,53 +88,27 @@ def collect_outputs(
input_lengths: torch.IntTensor,
targets: torch.IntTensor,
target_lengths: torch.IntTensor,
predictions: torch.Tensor = None,
) -> OrderedDict:
if predictions is None:
predictions = logits.max(-1)[1]
loss = self.criterion(
logits=logits,
targets=targets[:, 1:].contiguous().int(),
input_lengths=input_lengths.int(),
target_lengths=target_lengths.int(),
)

wer = self.wer_metric(targets[:, 1:], predictions)
cer = self.cer_metric(targets[:, 1:], predictions)

self.info({
f"{stage}_loss": loss,
f"{stage}_wer": wer,
f"{stage}_cer": cer,
"learning_rate": self.get_lr(),
})

return OrderedDict({
"loss": loss,
"wer": wer,
"cer": cer,
"predictions": predictions,
"targets": targets,
"logits": logits,
})
predictions = logits.max(-1)[1]

else:
wer = self.wer_metric(targets[:, 1:], predictions)
cer = self.cer_metric(targets[:, 1:], predictions)

self.info({
f"{stage}_wer": wer,
f"{stage}_cer": cer,
})

return OrderedDict({
"loss": None,
"wer": wer,
"cer": cer,
"predictions": predictions,
"targets": targets,
"logits": logits,
})
loss = self.criterion(
logits=logits,
targets=targets[:, 1:].contiguous().int(),
input_lengths=input_lengths.int(),
target_lengths=target_lengths.int(),
)

self.info({
f"{stage}_loss": loss,
"learning_rate": self.get_lr(),
})

return OrderedDict({
"loss": loss,
"predictions": predictions,
"targets": targets,
"logits": logits,
})

def _expand_for_joint(self, encoder_outputs: Tensor, decoder_outputs: Tensor) -> Tuple[Tensor, Tensor]:
input_length = encoder_outputs.size(1)
Expand Down Expand Up @@ -278,16 +252,15 @@ def validation_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
else:
encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)

max_length = encoder_outputs.size(1)
decoder_outputs, _ = self.decoder(targets, target_lengths)
logits = self.joint(encoder_outputs, decoder_outputs)

predictions = self.decode(encoder_outputs, max_length)
return self.collect_outputs(
'val',
logits=None,
logits=logits,
input_lengths=output_lengths,
targets=targets,
target_lengths=target_lengths,
predictions=predictions,
)

def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
Expand All @@ -308,14 +281,13 @@ def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
else:
encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)

max_length = encoder_outputs.size(1)
decoder_outputs, _ = self.decoder(targets, target_lengths)
logits = self.joint(encoder_outputs, decoder_outputs)

predictions = self.decode(encoder_outputs, max_length)
return self.collect_outputs(
'test',
logits=None,
logits=logits,
input_lengths=output_lengths,
targets=targets,
target_lengths=target_lengths,
predictions=predictions,
)
87 changes: 1 addition & 86 deletions openspeech/models/transformer_transducer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,89 +117,4 @@ def greedy_decode(self, encoder_outputs: Tensor, max_length: int) -> Tensor:

pred_tokens = torch.stack(pred_tokens, dim=1)

return torch.LongTensor(pred_tokens)

def forward(self, inputs: Tensor, input_lengths: Tensor) -> Dict[str, Tensor]:
r"""
Decode `encoder_outputs`.

Args:
inputs (torch.FloatTensor): A input sequence passed to encoders. Typically for inputs this will be a padded `FloatTensor` of size ``(batch, seq_length, dimension)``.
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``

Returns:
outputs (dict): Result of model predictions.
"""
encoder_outputs, _ = self.encoder(inputs, input_lengths)
max_length = encoder_outputs.size(1)

predictions = self.decode(encoder_outputs, max_length)
return {
"predictions": predictions,
"encoder_outputs": encoder_outputs,
}

def training_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
r"""
Forward propagate a `inputs` and `targets` pair for training.

Inputs:
batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
batch_idx (int): The index of batch

Returns:
loss (torch.Tensor): loss for training
"""
return super(TransformerTransducerModel, self).training_step(batch, batch_idx)

def validation_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
r"""
Forward propagate a `inputs` and `targets` pair for validation.

Inputs:
batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
batch_idx (int): The index of batch

Returns:
loss (torch.Tensor): loss for training
"""
inputs, targets, input_lengths, target_lengths = batch

encoder_outputs, _ = self.encoder(inputs, input_lengths)
max_length = encoder_outputs.size(1)

predictions = self.decode(encoder_outputs, max_length)
return self.collect_outputs(
'valid',
logits=None,
input_lengths=input_lengths,
targets=targets,
target_lengths=target_lengths,
predictions=predictions,
)

def test_step(self, batch: tuple, batch_idx: int) -> OrderedDict:
r"""
Forward propagate a `inputs` and `targets` pair for test.

Inputs:
batch (tuple): A train batch contains `inputs`, `targets`, `input_lengths`, `target_lengths`
batch_idx (int): The index of batch

Returns:
loss (torch.Tensor): loss for training
"""
inputs, targets, input_lengths, target_lengths = batch

encoder_outputs, _ = self.encoder(inputs, input_lengths)
max_length = encoder_outputs.size(1)

predictions = self.decode(encoder_outputs, max_length)
return self.collect_outputs(
'valid',
logits=None,
input_lengths=input_lengths,
targets=targets,
target_lengths=target_lengths,
predictions=predictions,
)
return torch.LongTensor(pred_tokens)