Skip to content

Commit

Permalink
name
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Sep 22, 2021
1 parent f50b09b commit 31198e8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchmetrics/functional/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,16 @@ def pit(
# calculate the metric matrix
batch_size, spk_num = target.shape[0:2]
metric_mtx = None
for tar_speech_idx in range(spk_num): # we have spk_num speeches in target in each sample
for est_speech_idx in range(spk_num): # we have spk_num speeches in preds in each sample
for target_idx in range(spk_num): # we have spk_num speeches in target in each sample
for preds_idx in range(spk_num): # we have spk_num speeches in preds in each sample
if metric_mtx is not None:
metric_mtx[:, tar_speech_idx, est_speech_idx] = metric_func(
preds[:, est_speech_idx, ...], target[:, tar_speech_idx, ...], **kwargs
metric_mtx[:, target_idx, preds_idx] = metric_func(
preds[:, preds_idx, ...], target[:, target_idx, ...], **kwargs
)
else:
first_ele = metric_func(preds[:, est_speech_idx, ...], target[:, tar_speech_idx, ...], **kwargs)
first_ele = metric_func(preds[:, preds_idx, ...], target[:, target_idx, ...], **kwargs)
metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=first_ele.dtype, device=first_ele.device)
metric_mtx[:, tar_speech_idx, est_speech_idx] = first_ele
metric_mtx[:, target_idx, preds_idx] = first_ele

# find best
op = torch.max if eval_func == "max" else torch.min
Expand Down

0 comments on commit 31198e8

Please sign in to comment.