# polyloss介绍

polyloss是Cross-entropy loss和Focal loss的优化版本，PolyLoss在二维图像分类、实例分割、目标检测和三维目标检测任务上都明显优于Cross-entropy loss和Focal loss。

PolyLoss将损失函数写成 $\sum_{j=1}^{n}\alpha_j(1-P_t)^j$ 的形式，其中 $\alpha_j\in\mathbb{R}^+$ 为多项式系数，$P_t$ 为目标类标签的预测概率。每个多项式基 $(1-P_t)^j$ 由相应的多项式系数 $\alpha_j\in\mathbb{R}^+$ 进行加权，这使PolyLoss能够很容易地调整不同的多项式基。

• 当 $\alpha_j=1/j$ 时，PolyLoss等价于常用的Cross-entropy loss，但这个系数分配可能不是最优的。

# tensorflow2.X实现

import tensorflow as tf

def poly1_cross_entropy(epsilon=1.0):
    """Build a Poly-1 cross-entropy loss function.

    The returned loss is the softmax cross-entropy plus ``epsilon * (1 - pt)``,
    where ``pt`` is the softmax probability mass on the target class.

    :param epsilon: weight of the additional ``(1 - pt)`` term
    :return: callable ``loss(y_true, y_pred)`` returning the mean Poly-1 loss
    """

    def _poly1_cross_entropy(y_true, y_pred):
        """Compute the mean Poly-1 cross-entropy.

        :param y_true: labels multiplied element-wise with the softmax output,
            i.e. one-hot (or soft) labels laid out like ``y_pred``
        :param y_pred: raw logits; softmax is taken over the last axis
        :return: scalar tensor, the mean of the per-example Poly-1 loss
        """
        # pt, ce and poly1 hold one value per example (shape [batch]).
        pt = tf.reduce_sum(y_true * tf.nn.softmax(y_pred), axis=-1)
        ce = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        poly1 = ce + epsilon * (1 - pt)
        return tf.reduce_mean(poly1)

    return _poly1_cross_entropy

def poly1_focal_loss(gamma=2.0, epsilon=1.0, alpha=0.25):
    """Build a Poly-1 focal loss function.

    The returned loss is the sigmoid focal loss plus
    ``epsilon * (1 - pt) ** (gamma + 1)``.

    :param gamma: focusing exponent applied to ``(1 - pt)``
    :param epsilon: weight of the additional polynomial term
    :param alpha: class-balancing weight for positive targets; the weighting
        is skipped entirely when ``alpha`` is negative
    :return: callable ``loss(y_true, y_pred)`` returning the mean Poly-1 loss
    """

    def _poly1_focal_loss(y_true, y_pred):
        """Compute the mean Poly-1 focal loss.

        :param y_true: binary targets with the same layout as ``y_pred``
        :param y_pred: raw logits (sigmoid is applied here)
        :return: scalar tensor, the mean over every element
        """
        p = tf.math.sigmoid(y_pred)
        ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        # pt is the probability assigned to the correct outcome of each element.
        pt = y_true * p + (1 - y_true) * (1 - p)
        focal = ce_loss * ((1 - pt) ** gamma)

        # Only the alpha weighting is optional; the Poly-1 term below is
        # always added, otherwise a negative alpha would leave it undefined.
        if alpha >= 0:
            alpha_t = alpha * y_true + (1 - alpha) * (1 - y_true)
            focal = alpha_t * focal

        poly1 = focal + epsilon * tf.math.pow(1 - pt, gamma + 1)
        return tf.reduce_mean(poly1)

    return _poly1_focal_loss



# pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class Poly1CrossEntropyLoss(nn.Module):
    """Poly-1 cross-entropy loss: ``CE + epsilon * (1 - pt)``.

    ``pt`` is the softmax probability assigned to the target class.
    """

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1CrossEntropyLoss
        :param num_classes: number of classes (size of the class axis of the logits)
        :param epsilon: weight of the ``(1 - pt)`` term added to the cross-entropy
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor;
            any other value leaves the loss unreduced
        """
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        :param logits: tensor of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: tensor of class indices of shape [N] or [N, ...]
        :return: poly cross-entropy loss, shaped like labels unless reduced
        """
        labels_onehot = F.one_hot(labels, num_classes=self.num_classes)
        # one_hot appends the class axis last, while logits carry it at dim 1;
        # realign so the product below is element-wise for [N, C, ...] inputs.
        if labels_onehot.ndim > 2:
            labels_onehot = labels_onehot.movedim(-1, 1)
        labels_onehot = labels_onehot.to(device=logits.device, dtype=logits.dtype)

        # For [N, C] (and unbatched [C]) logits this is the last axis, as before.
        class_dim = 1 if logits.ndim > 1 else 0
        pt = torch.sum(labels_onehot * F.softmax(logits, dim=class_dim), dim=class_dim)
        ce = F.cross_entropy(input=logits, target=labels, reduction="none")
        poly1 = ce + self.epsilon * (1 - pt)

        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1

class Poly1FocalLoss(nn.Module):
    """Poly-1 focal loss: sigmoid focal loss plus ``epsilon * (1 - pt) ** (gamma + 1)``."""

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 alpha: float = 0.25,
                 gamma: float = 2.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1FocalLoss
        :param num_classes: number of classes
        :param epsilon: poly loss epsilon
        :param alpha: focal loss alpha; a negative value disables alpha weighting
        :param gamma: focal loss gamma
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor;
            any other value leaves the loss unreduced
        """
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: ground truth of shape [N] or [N, ...], NOT one-hot encoded
        :return: poly focal loss, shaped like logits unless reduced
        """
        p = torch.sigmoid(logits)

        # Class indices -> one-hot. one_hot appends the class axis last:
        #   [N]      -> [N, num_classes]
        #   [N, ...] -> [N, ..., num_classes], moved to [N, num_classes, ...]
        # so that the targets line up element-wise with the logits.
        labels = F.one_hot(labels, num_classes=self.num_classes)
        if labels.ndim > 2:
            labels = labels.movedim(-1, 1)
        labels = labels.to(device=logits.device, dtype=logits.dtype)

        ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        # pt is the probability assigned to the correct outcome of each element.
        pt = labels * p + (1 - labels) * (1 - p)
        focal = ce_loss * ((1 - pt) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            focal = alpha_t * focal

        poly1 = focal + self.epsilon * torch.pow(1 - pt, self.gamma + 1)

        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()

        return poly1


THE END