import torch import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, gamma=2, alpha=None, reduction='mean'): super(FocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) self.reduction = reduction def forward(self, input, target): if input.dim()>2: input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W input = input.transpose(1,2) # N,C,H*W => N,H*W,C input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C target = target.view(-1,1) logpt = F.log_softmax(input,dim=1) logpt = logpt.gather(1,target) logpt = logpt.view(-1) pt = logpt.exp() if self.alpha is not None: if self.alpha.type()!=input.data.type(): self.alpha = self.alpha.type_as(input.data) at = self.alpha.gather(0,target.data.view(-1)) logpt = logpt * at loss = -1 * (1-pt)**self.gamma * logpt if self.reduction=='sum': loss = loss.sum() elif self.reduction=='mean': loss = loss.mean() return loss