当前位置:   article > 正文

三分类SHAP图(特征标准化之后怎么画)

shap图

画三分类SHAP图出错

今天干了一件很蠢的事情,还耽误了很多时间,特此记录一下
我将数据标准化之后训练模型,然后将未标准化的数据作为输入计算了SHAP值,得出的结果显然不对。类似于下图这种
在这里插入图片描述
但是如果画图时将X_test输入作为参数,那么横坐标就对应的是标准化之后的值,所以我们可以先对X_test未经标准化时候制作一个copy版本X_test1,然后作为画图时候参数输入就可以正确画出SHAP图的横坐标了,也可以得到我们想要的信息。另外三分类shap values得到一个3维数据,有时候使用起来需要切片,比如画单个特征的shap图,但是画总体概览图时候不用。

import shap
X_test = pd.DataFrame(X_test,columns=x_test_cols)
explainer = shap.TreeExplainer(lgb_model)
shap_values = explainer.shap_values(X_test)  # 传入特征矩阵X,计算SHAP值
plt.figure()
#plt.rcParams['figure.dpi'] = 300 #分辨率
plt.title('LightGBM model SHAP values')
shap.summary_plot(shap_values, X_test,show=False)
plt.savefig(save_path+'\shap'+'lgb.png',dpi=300,bbox_inches = 'tight')

shap.initjs()
shap.dependence_plot('Na1', shap_values[1], X_test,interaction_index=None,show=False) #注意:如皋这么画,那么SHAP横坐标就是标准化之后的值
plt.axhline(y=0, color="red",linestyle='-')

#shap.force_plot(explainer.expected_value[0], shap_values[0][0,:], X_test.iloc[0,:])
#shap.force_plot(explainer.expected_value[0], shap_values[0], X_test)
#shap.dependence_plot("Na1", shap_values[1], X_test)

import os
shap_path = save_path +r'\class1'
if not os.path.isdir(shap_path):
    os.makedirs(shap_path)
for i in X_test.columns.values.tolist():
    plt.figure()
    shap.dependence_plot(i, shap_values[1], X_test1,interaction_index=None,show=False)
    plt.axhline(y=0, color="red",linestyle='-') #X_test1是X_test的一个未经标准化的复制版本
    plt.savefig(shap_path+ "\shap"+str(i)+'.png',dpi=300,bbox_inches = 'tight')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

下面是一张正确的结果图
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/97506
推荐阅读
相关标签
  

闽ICP备14008679号