赞
踩
今天干了一件很蠢的事情,还耽误了很多时间,特此记录一下
我将数据标准化之后训练模型,然后将未标准化的数据作为输入计算了SHAP值,得出的结果显然不对。类似于下图这种
但是如果画图时将X_test输入作为参数,那么横坐标就对应的是标准化之后的值,所以我们可以先对X_test未经标准化时候制作一个copy版本X_test1,然后作为画图时候参数输入就可以正确画出SHAP图的横坐标了,也可以得到我们想要的信息。另外三分类shap values得到一个3维数据,有时候使用起来需要切片,比如画单个特征的shap图,但是画总体概览图时候不用。
import shap X_test = pd.DataFrame(X_test,columns=x_test_cols) explainer = shap.TreeExplainer(lgb_model) shap_values = explainer.shap_values(X_test) # 传入特征矩阵X,计算SHAP值 plt.figure() #plt.rcParams['figure.dpi'] = 300 #分辨率 plt.title('LightGBM model SHAP values') shap.summary_plot(shap_values, X_test,show=False) plt.savefig(save_path+'\shap'+'lgb.png',dpi=300,bbox_inches = 'tight') shap.initjs() shap.dependence_plot('Na1', shap_values[1], X_test,interaction_index=None,show=False) #注意:如皋这么画,那么SHAP横坐标就是标准化之后的值 plt.axhline(y=0, color="red",linestyle='-') #shap.force_plot(explainer.expected_value[0], shap_values[0][0,:], X_test.iloc[0,:]) #shap.force_plot(explainer.expected_value[0], shap_values[0], X_test) #shap.dependence_plot("Na1", shap_values[1], X_test) import os shap_path = save_path +r'\class1' if not os.path.isdir(shap_path): os.makedirs(shap_path) for i in X_test.columns.values.tolist(): plt.figure() shap.dependence_plot(i, shap_values[1], X_test1,interaction_index=None,show=False) plt.axhline(y=0, color="red",linestyle='-') #X_test1是X_test的一个未经标准化的复制版本 plt.savefig(shap_path+ "\shap"+str(i)+'.png',dpi=300,bbox_inches = 'tight')
下面是一张正确的结果图
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。