# focal loss 之 pytorch 实现

``````python
# coding=utf-8
import torch
import torch.nn.functional as F

from torch import nn
from torch.nn import CrossEntropyLoss
import numpy as np

class MultiFocalLoss(nn.Module):
    """Multi-class focal loss.

    Focal_Loss = -1 * alpha_t * ((1 - pt) ** gamma) * log(pt)

    where ``pt`` is the softmax probability of the target class and
    ``alpha_t`` is the weight of the target class.

    Args:
        num_class: number of classes.
        alpha: class balance factor, one weight per class. ``None`` (default)
            gives every class a weight of 0.5. An ``int``/``float`` (or a 0-d
            tensor / numpy scalar) is broadcast to all classes. A list, tuple,
            numpy array or 1-d tensor must have exactly ``num_class`` entries.
        gamma: focusing hyper-parameter; ``gamma=0`` reduces to
            alpha-weighted cross entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    Raises:
        RuntimeError: if ``alpha`` does not have ``num_class`` entries.
        ValueError: if ``reduction`` is not a supported value.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, num_class, alpha=None, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_class = num_class
        self.gamma = gamma
        self.reduction = reduction
        # Retained for backward compatibility only: forward() uses log_softmax,
        # which is numerically stable without an additive epsilon.
        self.smooth = 1e-4

        if alpha is None:
            alpha = torch.full((num_class,), 0.5)
        elif isinstance(alpha, (int, float)):
            alpha = torch.full((num_class,), float(alpha))
        else:
            alpha = torch.as_tensor(alpha)
            if alpha.dim() == 0:
                # 0-d tensor / numpy scalar: broadcast like a Python number.
                alpha = alpha.reshape(1).repeat(num_class)
        if alpha.dim() != 1 or alpha.shape[0] != num_class:
            raise RuntimeError('the length not equal to number of class')
        self.alpha = alpha

    def forward(self, logit, target):
        """Compute the focal loss.

        N: batch size, C: number of classes.

        :param logit: unnormalised scores, [N, C] or [N, C, d1, d2, ...].
            The class dimension must be dim 1.
        :param target: integer class indices, [N] or [N, d1, d2, ...].
        :return: a scalar for 'mean'/'sum'; a tensor shaped like ``target``
            for 'none'.
        """
        alpha = self.alpha.to(device=logit.device, dtype=logit.dtype)
        ori_shp = target.shape

        # Take the log of the probabilities directly. log(softmax + eps) can
        # push pt above 1, making (1 - pt) ** gamma NaN for non-integer gamma.
        log_prob = F.log_softmax(logit, dim=1)
        if log_prob.dim() > 2:
            # N,C,d1,d2,... -> N,C,m -> N,m,C -> N*m,C  (m = d1*d2*...)
            n, c = logit.shape[:2]
            log_prob = log_prob.reshape(n, c, -1).transpose(1, 2).reshape(-1, c)

        # gather/indexing need int64 indices; reshape also copes with
        # non-contiguous targets.
        target = target.reshape(-1).long()
        logpt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()
        alpha_weight = alpha[target]
        loss = -alpha_weight * (1.0 - pt).pow(self.gamma) * logpt

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss.view(ori_shp)

if __name__ == "__main__":
    # Small smoke test: build the loss, run a forward pass, and backpropagate.
    batch_size, seq_len, num_class = 1, 2, 3

    # 2-D input: logits [N, C], targets [N]
    loss_func = MultiFocalLoss(num_class=num_class, alpha=1, gamma=2, reduction='mean')
    logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_class)
    targets = torch.randint(0, num_class, size=(batch_size,))  # (batch_size,)
    loss = loss_func(logits, targets)
    print(loss)
    loss.backward()

    # Multi-dimensional input: logits [N, C, d1, ...], targets [N, d1, ...]
    loss_func = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')
    logits = torch.rand(batch_size, seq_len, num_class, requires_grad=True)  # (batch_size, seq_len, num_class)
    targets = torch.randint(0, num_class, size=(batch_size, seq_len))  # (batch_size, seq_len)

    # The class dimension must be dim 1, so permute to (batch_size, num_class, seq_len).
    loss = loss_func(logits.permute(0, 2, 1), targets)
    print(loss)
    loss.backward()
``````

THE END
