# pythonproject/FocalLoss.py
# Retrieved from a Git web viewer (41 lines, 1.3 KiB, Python);
# last commit: 2024-08-07 09:46:47 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Multi-class focal loss on raw logits.

    Per element this computes ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability of the target class and
    ``alpha_t`` is the weight of the target class (1 when ``alpha`` is None).

    Args:
        gamma: Focusing exponent. With ``gamma=0`` and ``alpha=None`` the
            loss equals plain cross-entropy.
        alpha: Optional class weights. A scalar ``a`` is expanded to
            ``[a, 1 - a]`` (class 0 gets ``a``, class 1 gets ``1 - a``).
            A list, tuple or tensor gives one weight per class index.
            ``None`` disables class weighting.
        reduction: ``'mean'`` or ``'sum'`` reduce to a scalar. Any other
            value returns the unreduced per-element loss.
    """

    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if isinstance(alpha, (float, int)):
            alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        elif alpha is not None:
            # Accept list, tuple or tensor alike (previously only list).
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # Non-persistent buffer: follows module.to()/.cuda() but adds no
        # state_dict entry, so previously saved checkpoints still load.
        self.register_buffer('alpha', alpha, persistent=False)
        self.reduction = reduction

    def forward(self, input, target):
        """Return the focal loss of ``input`` logits against ``target`` labels.

        Args:
            input: Unnormalized class scores, shape ``(N, C)`` or
                ``(N, C, d1, d2, ...)`` with classes on dim 1.
            target: Integer class indices with one entry per sample position,
                i.e. ``N * d1 * d2 * ...`` elements (e.g. ``(N,)`` or
                ``(N, d1, ...)``).

        Returns:
            A scalar for ``reduction='mean'``/``'sum'``; otherwise a 1-D
            tensor with one loss value per flattened sample position.
        """
        if input.dim() > 2:
            num_classes = input.size(1)
            # N,C,d1,... => N,C,D => N,D,C => N*D,C. reshape (unlike view)
            # also accepts non-contiguous inputs, e.g. transposed outputs.
            input = input.reshape(input.size(0), num_classes, -1)
            input = input.transpose(1, 2).reshape(-1, num_classes)
        # gather() requires int64 indices; reshape tolerates non-contiguous labels.
        target = target.reshape(-1, 1).long()

        logpt = F.log_softmax(input, dim=1).gather(1, target).reshape(-1)
        # p_t is taken before alpha weighting so it stays a pure probability.
        pt = logpt.exp()

        if self.alpha is not None:
            # Cast locally rather than overwriting self.alpha inside forward().
            alpha = self.alpha.to(device=logpt.device, dtype=logpt.dtype)
            logpt = logpt * alpha.gather(0, target.reshape(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()
        return loss