当前位置:   article > 正文

ROC曲线绘制(Python)_python将分训练组和验证组绘制roc曲线

python将分训练组和验证组绘制roc曲线

首先以支持向量机模型为例

先导入需要使用的包,我们将使用roc_curve这个函数绘制ROC曲线!

  1. from sklearn.svm import SVC
  2. from sklearn.metrics import roc_curve
  3. from sklearn.datasets import make_blobs
  4. from sklearn. model_selection import train_test_split
  5. import matplotlib.pyplot as plt
  6. %matplotlib inline

然后使用下面make_blobs函数,生成一个二分类的数据不平衡数据集;

使用train_test_split函数划分训练集和测试集数据;

训练SVC模型。

  1. X,y = make_blobs(n_samples=(4000,500), cluster_std=[7,2], random_state=0)
  2. X_train,X_test,y_train, y_test = train_test_split(X,y,random_state=0)
  3. clf = SVC(gamma=0.05).fit(X_train, y_train)

  1. fpr,tpr, thresholds = roc_curve(y_test,clf.decision_function(X_test))
  2. plt.plot(fpr,tpr,label='ROC')
  3. plt.xlabel('FPR')
  4. plt.ylabel('TPR')

从上面的代码可以看到,我们使用roc_curve函数生成三个变量,分别是fpr,tpr, thresholds,也就是假正例率(FPR)真正例率(TPR)和阈值。

而其中的fpr,tpr正是我们绘制ROC曲线的横纵坐标,于是我们以变量fpr为横坐标,tpr为纵坐标,绘制相应的ROC图像如下:

值得注意的是上面的支持向量机模型使用的decision_function函数,是自己所特有的,而其他模型不能直接使用。

比如说我们想要使用其他模型(例如决策树模型)的结果绘制ROC,直接套用上面的代码,会报错,会显示没有这个函数。

以决策树模型为例,解决上述问题(适用于除向量机外的模型)

导入决策树模型包以及训练模型的代码省略了,只需要手动改一改就行了,我们直接看绘图的代码!

  1. fpr,tpr, thresholds = roc_curve(y_test,clf.predict_proba(X_test)[:,1])
  2. plt.plot(fpr,tpr,label='ROC')
  3. plt.xlabel('FPR')
  4. plt.ylabel('TPR')

可以看到我们直接把只适用于支持向量机模型的函数decision_function更改成predict_proba(X_test)[:,1]就行了,让我们看看结果:

可以看到哈,决策树模型在这个数据集上的泛化能力不如支持向量机哈!!!学废了吗。

更好看的画法

  1. auc = roc_auc_score(y_test,clf.predict_proba(X_test)[:,1])
  2. # auc = roc_auc_score(y_test,clf.decision_function(X_test))
  3. fpr,tpr, thresholds = roc_curve(y_test,clf.decision_function(X_test))
  4. plt.plot(fpr,tpr,color='darkorange',label='ROC curve (area = %0.2f)' % auc)
  5. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  6. plt.xlim([0.0, 1.0])
  7. plt.ylim([0.0, 1.05])
  8. plt.xlabel('False Positive Rate')
  9. plt.ylabel('True Positive Rate')
  10. plt.title('Receiver operating characteristic example')
  11. plt.legend(loc="lower right")
  12. plt.savefig('suhan.jpg',dpi=800)
  13. plt.show()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/855773
推荐阅读
相关标签
  

闽ICP备14008679号