|
""" |
|
Permutation Invariant Training (PIT) loss |
|
|
|
Authors: |
|
* Jiatong Shi 2021 |
|
""" |
|
|
|
from itertools import permutations |
|
|
|
import numpy as np |
|
import torch |
|
|
|
__all__ = [ |
|
"pit_loss", |
|
] |
|
|
|
|
|
|
|
def create_length_mask(length, max_len, num_output, device):
    """Build a binary mask that zeroes out frames beyond each instance's valid length.

    Args:
        length: per-instance valid lengths, indexable with ``len()`` support
            (e.g. list of ints or a 1-D ``torch.LongTensor``)
        max_len (int): padded sequence length (time axis)
        num_output (int): number of output classes/speakers
        device (torch.device): device the mask should live on

    Returns:
        torch.FloatTensor: mask of shape ``(batch_size, max_len, num_output)``
        with 1 for valid frames and 0 for padding frames.
    """
    batch_size = len(length)
    # Allocate directly on the target device instead of building on CPU and
    # copying afterwards — saves a host allocation and a host-to-device transfer.
    mask = torch.zeros(batch_size, max_len, num_output, device=device)
    for i in range(batch_size):
        mask[i, : length[i], :] = 1
    return mask
|
|
|
|
|
|
|
def pit_loss_single_permute(output, label, length):
    """Compute the length-masked BCE loss for one fixed label permutation.

    Args:
        output (torch.FloatTensor): logits in ``(batch_size, seq_len, num_class)``
        label (torch.FloatTensor): targets with the same shape as ``output``
        length (torch.LongTensor): valid length of each instance

    Returns:
        torch.FloatTensor: per-instance loss of shape ``(batch_size, 1)``
    """
    criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
    valid = create_length_mask(length, label.size(1), label.size(2), label.device)
    # Element-wise loss, with padding frames zeroed out by the mask.
    per_frame = criterion(output, label) * valid
    # Average over classes, then accumulate over time for each instance.
    per_instance = per_frame.mean(dim=2).sum(dim=1)
    return per_instance.unsqueeze(1)
|
|
|
|
|
def pit_loss(output, label, length):
    """
    The Permutation Invariant Training loss

    Args:
        output (torch.FloatTensor): prediction in (batch_size, seq_len, num_class)
        label (torch.FloatTensor): label in the same shape as :code:`output`
        length (torch.LongTensor): the valid length of each instance. :code:`output` and :code:`label`
            share the same valid length

    Returns:
        tuple:

        1. loss (torch.FloatTensor)
        2. min_idx (int): the id with the minimum loss
        3. all the permutation
    """
    num_output = label.size(2)
    device = label.device
    # Enumerate every possible speaker-channel ordering of the label.
    permute_list = [np.array(p) for p in permutations(range(num_output))]
    # Per-instance loss for each permutation, stacked along dim=1:
    # shape (batch_size, num_permutations).
    losses = torch.cat(
        [pit_loss_single_permute(output, label[:, :, p], length) for p in permute_list],
        dim=1,
    )
    # Keep, per instance, the permutation with the smallest loss.
    min_loss, min_idx = torch.min(losses, dim=1)
    # Normalize by the total number of valid frames in the batch.
    loss = min_loss.sum() / length.float().to(device).sum()
    return loss, min_idx, permute_list
|
|
|
|
|
def get_label_perm(label, perm_idx, perm_list):
    """Reorder each instance's label channels by its selected permutation.

    Args:
        label (torch.FloatTensor): labels in ``(batch_size, seq_len, num_class)``
        perm_idx: per-instance index into ``perm_list`` (indexable, length batch_size)
        perm_list (list): permutations as returned by :func:`pit_loss`

    Returns:
        torch.FloatTensor: labels with channels permuted per instance,
        same shape as ``label``, always on CPU.
    """
    permuted = [
        label[b, :, perm_list[perm_idx[b]]].data.cpu().numpy()
        for b in range(len(perm_idx))
    ]
    return torch.from_numpy(np.array(permuted)).float()
|
|