scikit-learn 提供的绘图工具

2023-02-20 14:33 更新

Scikit-learn定义了一个简单的API,创建用于机器学习的可视化对象。该API的特点是无需重新计算即可进行快速绘图和视觉调整。在以下示例中,我们绘制了利用支持向量机算法产生的ROC曲线:

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

svc_disp = plot_roc_curve(svc, X_test, y_test)


返回的svc_disp对象使我们可以在以后的图中继续使用已经计算出的SVC的ROC曲线。在本例中,svc_disp是一个 RocCurveDisplay,它将计算得到的值储存到称作roc_aucfpr,和tpr的属性中。接下来,我们训练一个随机森林分类器,并使用Display对象的plot 方法再次绘制先前计算的roc曲线。

import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier(random_state=42)
rfc.fit(X_train, y_train)

ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)


请注意,我们传递alpha=0.8给绘图函数来调整曲线的透明度。

例子:
带有可视化API的ROC曲线
局部依赖的高级绘图
显示对象的可视化

5.1.1 函数

inspection.plot_partial_dependence(…[, …]) 部分依赖图。
metrics.plot_confusion_matrix(estimator, X, …) 绘制混淆矩阵。
metrics.plot_precision_recall_curve(…[, …]) 绘制二元分类器的精确度、召回率曲线。
metrics.plot_roc_curve(estimator, X, y, *) 绘制受试者工作特性(ROC)曲线。

5.1.2 可视化对象

inspection.PartialDependenceDisplay(…) 部分依赖图(PDP)可视化。
metrics.ConfusionMatrixDisplay(…[, …]) 混淆矩阵可视化。
metrics.PrecisionRecallDisplay(precision, …) 精确度、召回率可视化。
metrics.RocCurveDisplay(*, fpr, tpr[, …]) ROC曲线可视化。


以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号