报错记录:TypeError: classification_report() takes 2 positional arguments but 3 were given

问题描述

  • 今天在使用sklearn_crfsuite.metrics.flat_classification_report函数的时候突然报错:TypeError: classification_report() takes 2 positional arguments but 3 were given,这里对该函数进行了详细剖析,找到报错原因,并给出解决办法。
    在这里插入图片描述

函数详细剖析

  • 使用sklearn_crfsuite拓展包中的metrics查看指标
from sklearn_crfsuite import metrics
# y_true为真实标签、y_pred为预测标签、labels为想要查看指标的标签(通常去除'O')
print(metrics.flat_classification_report(y_true, y_pred, labels=sort_labels))
  • 输入y_true和y_pred的输入格式通常如下图,即长度为序列个数,每个子列表为序列中每个字符的类别
    在这里插入图片描述
  • 跳转到sklearn_crfsuite.metrics.flat_classification_report函数中发现,实际调用的还是sklearn.metrics.classification_report,并使用@_flattens_y装饰器对输入数据进行展平后输入
    在这里插入图片描述
    在这里插入图片描述
  • 展平后的数据形式如下图
    在这里插入图片描述
  • 从flat_classification_report中可以看到,它将label坐标作为位置参数传入,即未指定参数名称
  • 但在下图的sklearn.metrics.classification_report函数中可以看到,其只接收y_true和y_pred两个位置参数,第三个*代表后面的参数必须指定参数名称,从而导致传入的labels成了多余的参数而报错
    在这里插入图片描述

解决办法

  • 既然都是调用的sklearn中的方法,那就自己重新实现以下过程,即展平后输入
from sklearn import metrics

y_true = [label for y in y_true for label in y]
y_pred = [label for y in y_pred for label in y]

print(metrics.classification_report(
    y_true, y_pred, labels=sort_labels
))
  • 或者使用sklearn_crfsuite里的方式进行展平
from sklearn import metrics
from itertools import chain

y_true = list(chain.from_iterable(y_true))
y_pred = list(chain.from_iterable(y_pred))

print(metrics.classification_report(
    y_true, y_pred, labels=sort_labels
))

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>