import torch
from torch import nn
|
class AutomaticWeightedLoss(nn.Module): |
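    """Automatically weighted multi-task loss.

    Learns one weight parameter per task and combines the task losses as

        sum_i 0.5 / params[i]**2 * loss_i + log(1 + params[i]**2)

    in the spirit of uncertainty-based loss weighting (Kendall et al.,
    "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene
    Geometry and Semantics", CVPR 2018), with a log(1 + sigma**2)
    regularizer that keeps the penalty term non-negative.

    Args:
        num (int): number of task losses to combine.
        args: optional namespace; if provided and ``args.use_awl`` is
            falsy, the weights are frozen at 1.0 so the module acts as
            an unweighted baseline.

    Example:
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """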
    def __init__(self, num=2, args=None):
        super().__init__()
        # Weights are trainable by default (args is None) or when args.use_awl
        # is set; otherwise they are frozen at 1.0.
        trainable = args is None or bool(args.use_awl)
        self.params = nn.Parameter(torch.ones(num), requires_grad=trainable)
|
    def forward(self, *x):
        """Sum the given task losses, each scaled by its learned weight."""
        loss_sum = 0
        for i, loss in enumerate(x):
            # 0.5 / params[i]**2 scales the i-th task loss; log(1 + params[i]**2)
            # penalizes trivially large weights and stays non-negative.
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum
|
|
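# A minimal usage sketch (hypothetical two-task setup; the model, data, and
# individual losses below are illustrative, not part of this module):
if __name__ == "__main__":
    awl = AutomaticWeightedLoss(num=2)
    model = nn.Linear(10, 1)  # placeholder model for the sketch
    # Include awl.parameters() in the optimizer so the task weights are learned.
    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": awl.parameters()}],
        lr=1e-3,
    )

    x = torch.randn(4, 10)
    y1, y2 = torch.randn(4, 1), torch.randn(4, 1)
    out = model(x)
    loss1 = nn.functional.mse_loss(out, y1)
    loss2 = nn.functional.l1_loss(out, y2)

    # Combine the per-task losses with learned weights and step as usual.
    total = awl(loss1, loss2)
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    print(total.item(), awl.params.detach())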