报错记录:TypeError: classification_report() takes 2 positional arguments but 3 were given
问题描述
- 今天在使用sklearn_crfsuite.metrics.flat_classification_report函数的时候突然报错:TypeError: classification_report() takes 2 positional arguments but 3 were given,这里对该函数进行了详细剖析,找到报错原因,并给出解决办法。
函数详细剖析
- 使用sklearn_crfsuite拓展包中的metrics查看指标
from sklearn_crfsuite import metrics
# y_true为真实标签、y_pred为预测标签、labels为想要查看指标的标签(通常去除'O')
print(metrics.flat_classification_report(y_true, y_pred, labels=sort_labels))
- 输入y_true和y_pred的输入格式通常如下图,即长度为序列个数,每个子列表为序列中每个字符的类别
- 跳转到sklearn_crfsuite.metrics.flat_classification_report函数中发现,实际调用的还是sklearn.metrics.classification_report,并使用@_flattens_y装饰器对输入数据进行展平后输入
- 展平后的数据形式如下图
- 从flat_classification_report中可以看到,它将label坐标作为位置参数传入,即未指定参数名称
- 但在下图的sklearn.metrics.classification_report函数中可以看到,其只接收y_true和y_pred两个位置参数,第三个*代表后面的参数必须指定参数名称,从而导致传入的labels成了多余的参数而报错
解决办法
- 既然都是调用的sklearn中的方法,那就自己重新实现以下过程,即展平后输入
from sklearn import metrics
y_true = [label for y in y_true for label in y]
y_pred = [label for y in y_pred for label in y]
print(metrics.classification_report(
y_true, y_pred, labels=sort_labels
))
- 或者使用sklearn_crfsuite里的方式进行展平
from sklearn import metrics
from itertools import chain
y_true = list(chain.from_iterable(y_true))
y_pred = list(chain.from_iterable(y_pred))
print(metrics.classification_report(
y_true, y_pred, labels=sort_labels
))
本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
二维码