tensorflow2.X和pytorch实现polyloss

polyloss介绍

PolyLoss 可以看作是对 Cross-entropy loss 和 Focal loss 的统一与推广。论文报告称,PolyLoss 在二维图像分类、实例分割、目标检测和三维目标检测任务上都明显优于 Cross-entropy loss 和 Focal loss。

作者认为可以将常用的分类损失函数,如Cross-entropy loss和Focal loss,分解为一系列加权多项式基。

它们都可以被分解为 $\sum_{j=1}^{\infty} \alpha_j (1-P_t)^j$ 的形式,其中 $\alpha_j \in \mathbb{R}^+$ 为多项式系数,$P_t$ 为模型对目标类标签的预测概率。每个多项式基 $(1-P_t)^j$ 都由相应的多项式系数 $\alpha_j \in \mathbb{R}^+$ 进行加权,这使 PolyLoss 能够很容易地调整不同多项式基的权重。

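  • 当 $\alpha_j = 1/j$ 时,PolyLoss 等价于常用的 Cross-entropy loss(由泰勒展开 $-\log P_t = \sum_{j=1}^{\infty} \frac{1}{j}(1-P_t)^j$ 可得),但这个系数分配可能不是最优的。

下文代码实现的是其中最简单的 Poly-1 形式,即只对第一项系数加一个扰动 $\epsilon$。对交叉熵,$L_{\text{Poly-1}} = -\log P_t + \epsilon (1-P_t)$;对 Focal loss,$L_{\text{Poly-1}} = \mathrm{FL} + \epsilon (1-P_t)^{\gamma+1}$。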

tensorflow2.X实现

import tensorflow as tf


def poly1_cross_entropy(epsilon=1.0):
    """Build a Keras-style Poly-1 softmax cross-entropy loss.

    Poly-1 adds ``epsilon * (1 - pt)`` to the softmax cross-entropy, where
    ``pt`` is the softmax probability assigned to the target class(es).

    Args:
        epsilon: Weight of the extra first-order polynomial term. With
            ``epsilon=0`` the loss is plain softmax cross-entropy.

    Returns:
        A function ``(y_true, y_pred) -> scalar tensor``. ``y_pred`` holds
        unnormalized logits over the last axis. ``y_true`` holds one-hot (or
        soft, e.g. label-smoothed) targets of the same shape. The per-example
        losses are averaged with ``tf.reduce_mean``.
    """
    def _poly1_cross_entropy(y_true, y_pred):
        # Cast targets to the logits' dtype: integer targets (or targets with a
        # different float precision) would otherwise cause a dtype mismatch in
        # the multiplication and in softmax_cross_entropy_with_logits below.
        y_true = tf.cast(y_true, y_pred.dtype)
        # For [batch, classes] inputs, pt, CE and Poly1 have shape [batch].
        pt = tf.reduce_sum(y_true * tf.nn.softmax(y_pred, axis=-1), axis=-1)
        CE = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        Poly1 = CE + epsilon * (1 - pt)
        return tf.reduce_mean(Poly1)
    return _poly1_cross_entropy


def poly1_focal_loss(gamma=2.0, epsilon=1.0, alpha=0.25):
    """Build a Keras-style Poly-1 sigmoid focal loss.

    Computes ``FL + epsilon * (1 - pt) ** (gamma + 1)`` element-wise, where
    ``FL`` is the (optionally alpha-weighted) sigmoid focal loss.

    Args:
        gamma: Focal loss focusing parameter.
        epsilon: Weight of the extra polynomial term ``(1 - pt) ** (gamma + 1)``.
        alpha: Class-balancing weight for positives; negatives get
            ``1 - alpha``. A negative value disables alpha weighting. The
            weighting is applied to the focal term only, not to the poly term.

    Returns:
        A function ``(y_true, y_pred) -> scalar tensor``. ``y_pred`` holds
        logits. ``y_true`` holds 0/1 targets of the same shape. All element
        losses are averaged with ``tf.reduce_mean``.
    """
    def _poly1_focal_loss(y_true, y_pred):
        # Cast targets to the logits' dtype so the element-wise arithmetic and
        # sigmoid_cross_entropy_with_logits do not fail on a dtype mismatch.
        y_true = tf.cast(y_true, y_pred.dtype)
        p = tf.math.sigmoid(y_pred)
        ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        # pt: predicted probability of the ground-truth outcome for each entry.
        pt = y_true * p + (1 - y_true) * (1 - p)
        FL = ce_loss * tf.math.pow(1 - pt, gamma)

        if alpha >= 0:
            alpha_t = alpha * y_true + (1 - alpha) * (1 - y_true)
            FL = alpha_t * FL
        Poly1 = FL + epsilon * tf.math.pow(1 - pt, gamma + 1)
        return tf.reduce_mean(Poly1)
    return _poly1_focal_loss


pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F


class Poly1CrossEntropyLoss(nn.Module):
    """Poly-1 softmax cross-entropy loss: ``CE + epsilon * (1 - pt)``.

    ``pt`` is the softmax probability of the target class. Logits may be
    ``[N, C]`` or ``[N, C, d1, ...]`` (class axis 1, as in ``F.cross_entropy``).
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1CrossEntropyLoss
        :param num_classes: number of classes (kept for API compatibility; the
            class count used in forward is taken from ``logits``)
        :param epsilon: weight of the extra (1 - pt) term; 0 gives plain cross-entropy
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :raises ValueError: if ``reduction`` is not one of none|sum|mean
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: tensor of shape [N, num_classes] or [N, num_classes, d1, ...]
        :param labels: class indices of shape [N] (or [N, d1, ...]), or floating-point
            class probabilities with the same shape as ``logits``
        :return: poly cross-entropy loss, shape [N] (or [N, d1, ...]) when
            reduction is "none", otherwise a scalar
        """
        CE = F.cross_entropy(input=logits, target=labels, reduction="none")
        if labels.is_floating_point():
            # Probability / soft targets: pt is the expected target probability.
            class_dim = 1 if logits.dim() > 1 else 0
            pt = torch.sum(labels * F.softmax(logits, dim=class_dim), dim=class_dim)
        else:
            # With index targets CE = -log(pt), so pt is recovered without
            # materializing a one-hot tensor; this also works for [N, C, ...].
            pt = torch.exp(-CE)
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1


class Poly1FocalLoss(nn.Module):
    """Poly-1 sigmoid focal loss: ``FL + epsilon * (1 - pt) ** (gamma + 1)``.

    The loss is computed element-wise over a one-hot (or multi-label) target
    tensor shaped like ``logits``; ``alpha`` weighting applies to the focal
    term only.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 alpha: float = 0.25,
                 gamma: float = 2.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1FocalLoss
        :param num_classes: number of classes (used to one-hot encode index labels)
        :param epsilon: poly loss epsilon
        :param alpha: focal loss alpha; a negative value disables alpha weighting
        :param gamma: focal loss gamma
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :raises ValueError: if ``reduction`` is not one of none|sum|mean
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _encode_labels(self, logits, labels):
        """Return float targets with the same shape, device and dtype as ``logits``."""
        if labels.shape == logits.shape:
            # Already one-hot / multi-label targets: use them as-is.
            return labels.to(device=logits.device, dtype=logits.dtype)
        # Index labels [N] -> [N, num_classes]; [N, ...] -> [N, ..., num_classes].
        one_hot = F.one_hot(labels, num_classes=self.num_classes)
        if labels.ndim > 1:
            # Move the class axis next to the batch axis: [N, num_classes, ...],
            # matching the logits layout for e.g. segmentation.
            one_hot = one_hot.movedim(-1, 1)
        return one_hot.to(device=logits.device, dtype=logits.dtype)

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: class indices of shape [N] or [N, ...], or a one-hot / multi-label
            tensor with the same shape as ``logits``
        :return: poly focal loss, same shape as ``logits`` when reduction is "none",
            otherwise a scalar
        """
        # focal loss implementation taken from
        # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
        labels = self._encode_labels(logits, labels)

        p = torch.sigmoid(logits)
        ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        # pt: predicted probability of the ground-truth outcome for each entry.
        pt = labels * p + (1 - labels) * (1 - p)
        FL = ce_loss * ((1 - pt) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            FL = alpha_t * FL

        poly1 = FL + self.epsilon * torch.pow(1 - pt, self.gamma + 1)

        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()

        return poly1

总结

以 ResNet-18 在花朵识别数据集上的实验为例。训练过程的 loss 曲线图已经丢失,因此这里只比较了验证集上的结果:使用 PolyLoss 后,验证集准确率提升了约 6%。数据集如下:
链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
提取码:bhjx
有兴趣的小伙伴可以自己尝试一下。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>