Skip to content

Commit

Permalink
train with dict input (#2242)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Dec 14, 2023
1 parent d9d8bee commit d508589
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
19 changes: 7 additions & 12 deletions wenet/transformer/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,14 @@ def __init__(
@torch.jit.ignore(drop=True)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
"""Frontend + Encoder + Decoder + Calc loss"""
speech = batch['feats'].to(device)
speech_lengths = batch['feats_lengths'].to(device)
text = batch['target'].to(device)
text_lengths = batch['target_lengths'].to(device)

assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
Expand Down
10 changes: 2 additions & 8 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,20 +456,14 @@ def batch_forward(model, batch, scaler, info_dict):
with torch.cuda.amp.autocast(enabled=dtype is not None,
dtype=dtype,
cache_enabled=False):
loss_dict = model(batch["feats"].to(device),
batch["feats_lengths"].to(device),
batch["target"].to(device),
batch["target_lengths"].to(device))
loss_dict = model(batch, device)
else:
# torch_ddp
# autocast context
# The more details about amp can be found in
# https://pytorch.org/docs/stable/notes/amp_examples.html
with torch.cuda.amp.autocast(scaler is not None):
loss_dict = model(batch["feats"].to(device),
batch["feats_lengths"].to(device),
batch["target"].to(device),
batch["target_lengths"].to(device))
loss_dict = model(batch, device)
info_dict['loss_dict'] = loss_dict

return info_dict
Expand Down

0 comments on commit d508589

Please sign in to comment.