SHAP Tutorial
This article mainly covers:

SHAP explains a model's output by computing, for every sample, the contribution of each feature to the prediction. Inspired by cooperative game theory, SHAP builds an additive explanation model in which every feature is treated as a "contributor". For each sample the model produces a prediction, and the SHAP values are the amounts allocated to each feature of that sample.

Let the $i$-th sample be $x_i$, its $j$-th feature be $x_i^j$, the model's prediction for that sample be $y_i$, and the model's baseline (usually the mean of the target over all samples) be $y_{base}$. The SHAP values then satisfy

$$y_i = y_{base} + f(x_i^1) + f(x_i^2) + \dots + f(x_i^j)$$

where $f(x_i^j)$ is the SHAP value of $x_i^j$. Intuitively, $f(x_i^1)$ is the contribution of the first feature of the $i$-th sample to the final prediction $y_i$: when $f(x_i^1) > 0$, the feature pushes the prediction up (a positive effect); otherwise it pulls the prediction down (a negative effect). The values $f(x_i^j)$ are obtained by a fitting procedure; how exactly they are fitted is beyond the scope of this tutorial.
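As a purely illustrative sketch with made-up numbers (not taken from any model in this tutorial), the identity just says that the per-feature contributions must add up to the gap between the prediction and the baseline:

# Hypothetical numbers, only to illustrate the additivity identity.
y_base = 0.38                      # baseline: mean prediction over all samples
contributions = {'Pclass': -0.12,  # assumed SHAP values for one sample
                 'Sex':    +0.30,
                 'Age':    +0.05,
                 'Fare':   -0.02}

y_i = y_base + sum(contributions.values())
print(y_i)   # approx. 0.59 -> the contributions account exactly for y_i - y_base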
This tutorial uses the Titanic dataset as its running example.
A quick look at the Titanic features:
PassengerId: passenger ID
Survived: survival (0 = died, 1 = survived)
Pclass: passenger class (1st/2nd/3rd)
Name: passenger name
Sex: sex
Age: age
SibSp: number of siblings/spouses aboard
Parch: number of parents/children aboard
Ticket: ticket number
Fare: fare
Cabin: cabin
Embarked: port of embarkation
This is a binary classification problem: Survived is the target, and Pclass, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, and Embarked are the features.
We will use a gradient-boosted tree ensemble (LightGBM) as the example model.
Note: the problem is deliberately kept simple so that the focus stays on how SHAP is applied.
# Fetch the data
!git clone https://github.com/pangpang97/shap_tutorial
Cloning into 'shap_tutorial'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (10/10), done.
remote: Total 12 (delta 1), reused 0 (delta 0), pack-reused 0
Unpacking objects: 100% (12/12), done.
# Import the required packages
import pandas as pd
import numpy as np
all_data = pd.read_csv('shap_tutorial/train_titantic.csv')
all_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  891 non-null    int64
 1   Survived     891 non-null    int64
 2   Pclass       891 non-null    int64
 3   Name         891 non-null    object
 4   Sex          891 non-null    object
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64
 7   Parch        891 non-null    int64
 8   Ticket       891 non-null    object
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object
 11  Embarked     889 non-null    object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
# Drop columns we do not need for now
all_data.drop(['PassengerId','Name'], axis=1, inplace=True)
# Fill missing values
all_data['Age'].fillna(0, inplace=True)
all_data['Cabin'].fillna('UNK', inplace=True)
all_data['Embarked'].fillna('UNK', inplace=True)
all_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 10 columns):
 #   Column    Non-Null Count  Dtype
---  ------    --------------  -----
 0   Survived  891 non-null    int64
 1   Pclass    891 non-null    int64
 2   Sex       891 non-null    object
 3   Age       891 non-null    float64
 4   SibSp     891 non-null    int64
 5   Parch     891 non-null    int64
 6   Ticket    891 non-null    object
 7   Fare      891 non-null    float64
 8   Cabin     891 non-null    object
 9   Embarked  891 non-null    object
dtypes: float64(2), int64(4), object(4)
memory usage: 69.7+ KB
# Label-encode Sex, Ticket, Cabin and Embarked
from sklearn.preprocessing import LabelEncoder
cate_cols = ['Sex','Ticket','Cabin','Embarked']
for cc in cate_cols:
    enc = LabelEncoder()
    all_data[cc] = enc.fit_transform(all_data[cc])
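If you later need to map the encoded integers back to the original categories (for example when reading SHAP plots), a small variation of the loop above that keeps each fitted encoder might look like this (a sketch; not required for the rest of the tutorial):

# Variation of the encoding loop that keeps the fitted encoders around.
encoders = {}
for cc in cate_cols:
    enc = LabelEncoder()
    all_data[cc] = enc.fit_transform(all_data[cc])
    encoders[cc] = enc

# Example: decode an encoded Sex value back to the original label.
# print(encoders['Sex'].inverse_transform([0, 1]))   # e.g. ['female' 'male']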
# Split into training and test sets
from sklearn.model_selection import train_test_split
train, test = train_test_split(all_data, test_size=0.2)

# Train a model with LightGBM
import lightgbm as lgb
from sklearn import metrics

params = {'objective': 'binary',
          'metric': 'binary_logloss',
          'num_round': 80,
          'verbose': 1
          }
num_round = params.pop('num_round', 1000)

xtrain = lgb.Dataset(train.drop(columns=['Survived']), train['Survived'], free_raw_data=False)
xeval = lgb.Dataset(test.drop(columns=['Survived']), test['Survived'], free_raw_data=False)
evallist = [xtrain, xeval]
clf = lgb.train(params, xtrain, num_round, valid_sets=evallist)

ytrain = np.where(clf.predict(train.drop(columns=['Survived'])) >= 0.5, 1, 0)
ytest = np.where(clf.predict(test.drop(columns=['Survived'])) >= 0.5, 1, 0)

print("train classification report")
print(metrics.classification_report(train['Survived'], ytrain))
print('*'*60)
print("test classification report")
print(metrics.classification_report(test['Survived'], ytest))
[1] training's binary_logloss: 0.616752 valid_1's binary_logloss: 0.628188 [2] training's binary_logloss: 0.577503 valid_1's binary_logloss: 0.59757 [3] training's binary_logloss: 0.545282 valid_1's binary_logloss: 0.573299 [4] training's binary_logloss: 0.517894 valid_1's binary_logloss: 0.55145 [5] training's binary_logloss: 0.49365 valid_1's binary_logloss: 0.53399 [6] training's binary_logloss: 0.472167 valid_1's binary_logloss: 0.518513 [7] training's binary_logloss: 0.453641 valid_1's binary_logloss: 0.506068 [8] training's binary_logloss: 0.437905 valid_1's binary_logloss: 0.494044 [9] training's binary_logloss: 0.424174 valid_1's binary_logloss: 0.485643 [10] training's binary_logloss: 0.410632 valid_1's binary_logloss: 0.477679 [11] training's binary_logloss: 0.398761 valid_1's binary_logloss: 0.470235 [12] training's binary_logloss: 0.388724 valid_1's binary_logloss: 0.464987 [13] training's binary_logloss: 0.379662 valid_1's binary_logloss: 0.460591 [14] training's binary_logloss: 0.370902 valid_1's binary_logloss: 0.457459 [15] training's binary_logloss: 0.363174 valid_1's binary_logloss: 0.455261 [16] training's binary_logloss: 0.356024 valid_1's binary_logloss: 0.451337 [17] training's binary_logloss: 0.348805 valid_1's binary_logloss: 0.447478 [18] training's binary_logloss: 0.340424 valid_1's binary_logloss: 0.444287 [19] training's binary_logloss: 0.3332 valid_1's binary_logloss: 0.439115 [20] training's binary_logloss: 0.326891 valid_1's binary_logloss: 0.436567 [21] training's binary_logloss: 0.321057 valid_1's binary_logloss: 0.434214 [22] training's binary_logloss: 0.315413 valid_1's binary_logloss: 0.433245 [23] training's binary_logloss: 0.310579 valid_1's binary_logloss: 0.434182 [24] training's binary_logloss: 0.305809 valid_1's binary_logloss: 0.43319 [25] training's binary_logloss: 0.301533 valid_1's binary_logloss: 0.432078 [26] training's binary_logloss: 0.295784 valid_1's binary_logloss: 0.430593 [27] training's binary_logloss: 0.290461 valid_1's binary_logloss: 0.428543 [28] training's binary_logloss: 0.285622 valid_1's binary_logloss: 0.426985 [29] training's binary_logloss: 0.279571 valid_1's binary_logloss: 0.426044 [30] training's binary_logloss: 0.274883 valid_1's binary_logloss: 0.427347 [31] training's binary_logloss: 0.270582 valid_1's binary_logloss: 0.427791 [32] training's binary_logloss: 0.265887 valid_1's binary_logloss: 0.428961 [33] training's binary_logloss: 0.260971 valid_1's binary_logloss: 0.430992 [34] training's binary_logloss: 0.254612 valid_1's binary_logloss: 0.429185 [35] training's binary_logloss: 0.250871 valid_1's binary_logloss: 0.429506 [36] training's binary_logloss: 0.245969 valid_1's binary_logloss: 0.425777 [37] training's binary_logloss: 0.24145 valid_1's binary_logloss: 0.425065 [38] training's binary_logloss: 0.237223 valid_1's binary_logloss: 0.423375 [39] training's binary_logloss: 0.233457 valid_1's binary_logloss: 0.42279 [40] training's binary_logloss: 0.229837 valid_1's binary_logloss: 0.421586 [41] training's binary_logloss: 0.2258 valid_1's binary_logloss: 0.419546 [42] training's binary_logloss: 0.222053 valid_1's binary_logloss: 0.420543 [43] training's binary_logloss: 0.218142 valid_1's binary_logloss: 0.419648 [44] training's binary_logloss: 0.214406 valid_1's binary_logloss: 0.417376 [45] training's binary_logloss: 0.21062 valid_1's binary_logloss: 0.417219 [46] training's binary_logloss: 0.207693 valid_1's binary_logloss: 0.418382 [47] training's binary_logloss: 0.204459 valid_1's binary_logloss: 0.420574 
[48] training's binary_logloss: 0.201661 valid_1's binary_logloss: 0.420458 [49] training's binary_logloss: 0.198652 valid_1's binary_logloss: 0.420256 [50] training's binary_logloss: 0.195849 valid_1's binary_logloss: 0.41788 [51] training's binary_logloss: 0.192828 valid_1's binary_logloss: 0.419856 [52] training's binary_logloss: 0.189455 valid_1's binary_logloss: 0.419239 [53] training's binary_logloss: 0.186862 valid_1's binary_logloss: 0.418061 [54] training's binary_logloss: 0.184144 valid_1's binary_logloss: 0.420203 [55] training's binary_logloss: 0.18186 valid_1's binary_logloss: 0.419781 [56] training's binary_logloss: 0.179336 valid_1's binary_logloss: 0.418251 [57] training's binary_logloss: 0.176953 valid_1's binary_logloss: 0.418373 [58] training's binary_logloss: 0.174188 valid_1's binary_logloss: 0.421177 [59] training's binary_logloss: 0.171624 valid_1's binary_logloss: 0.422029 [60] training's binary_logloss: 0.169554 valid_1's binary_logloss: 0.421043 [61] training's binary_logloss: 0.167043 valid_1's binary_logloss: 0.420784 [62] training's binary_logloss: 0.164732 valid_1's binary_logloss: 0.421378 [63] training's binary_logloss: 0.162674 valid_1's binary_logloss: 0.421023 [64] training's binary_logloss: 0.161026 valid_1's binary_logloss: 0.422021 [65] training's binary_logloss: 0.159152 valid_1's binary_logloss: 0.423376 [66] training's binary_logloss: 0.157098 valid_1's binary_logloss: 0.423112 [67] training's binary_logloss: 0.154923 valid_1's binary_logloss: 0.424078 [68] training's binary_logloss: 0.152742 valid_1's binary_logloss: 0.423512 [69] training's binary_logloss: 0.150701 valid_1's binary_logloss: 0.422217 [70] training's binary_logloss: 0.149009 valid_1's binary_logloss: 0.420999 [71] training's binary_logloss: 0.147453 valid_1's binary_logloss: 0.421606 [72] training's binary_logloss: 0.145442 valid_1's binary_logloss: 0.4223 [73] training's binary_logloss: 0.143639 valid_1's binary_logloss: 0.421754 [74] training's binary_logloss: 0.14174 valid_1's binary_logloss: 0.421546 [75] training's binary_logloss: 0.13988 valid_1's binary_logloss: 0.423675 [76] training's binary_logloss: 0.138371 valid_1's binary_logloss: 0.423968 [77] training's binary_logloss: 0.136794 valid_1's binary_logloss: 0.424537 [78] training's binary_logloss: 0.134842 valid_1's binary_logloss: 0.425823 [79] training's binary_logloss: 0.132775 valid_1's binary_logloss: 0.424314 [80] training's binary_logloss: 0.131154 valid_1's binary_logloss: 0.42698 train classification report precision recall f1-score support 0 0.97 0.99 0.98 439 1 0.99 0.95 0.97 273 accuracy 0.98 712 macro avg 0.98 0.97 0.97 712 weighted avg 0.98 0.98 0.98 712 ************************************************************ test classification report precision recall f1-score support 0 0.83 0.93 0.88 110 1 0.86 0.70 0.77 69 accuracy 0.84 179 macro avg 0.84 0.81 0.82 179 weighted avg 0.84 0.84 0.83 179
!pip install shap
import warnings
warnings.filterwarnings("ignore")
import shap
shap.initjs()
Collecting shap
  Downloading shap-0.40.0-cp37-cp37m-manylinux2010_x86_64.whl (564 kB)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from shap) (1.4.1)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from shap) (1.0.1)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from shap) (1.19.5)
Collecting slicer==0.0.7
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap) (21.3)
Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap) (4.62.3)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from shap) (1.1.5)
Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap) (0.51.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>20.9->shap) (3.0.6)
Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap) (0.34.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba->shap) (57.4.0)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2018.9)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2.8.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->shap) (1.15.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (3.0.0)
Installing collected packages: slicer, shap
Successfully installed shap-0.40.0 slicer-0.0.7
# With a large dataset this step can be very slow
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(train.drop(columns=['Survived']))  # compute the SHAP values
np.array(shap_values).shape  # inspect the shape of the SHAP values
(2, 712, 9)
The SHAP values form a three-dimensional array, so each sample gets two sets of SHAP values (one per class).
The first dimension selects the class: index 0 corresponds to class 0 (negative) and index 1 to class 1 (positive).
The remaining two dimensions index the samples and the features.
(shap_values[0] == -1* shap_values[1]).all()
True
As you can see, the SHAP values of class 0 are exactly the negatives of the SHAP values of class 1.
Possible application scenarios:
As an example, let's look at the positive-class SHAP values of the first sample.
train.drop(columns=['Survived']).iloc[0].T
Pclass 3.0000
Sex 1.0000
Age 20.0000
SibSp 1.0000
Parch 1.0000
Ticket 187.0000
Fare 15.7417
Cabin 147.0000
Embarked 0.0000
Name: 622, dtype: float64
# Method 1: inspect the feature contributions of a single sample
shap.initjs()  # Colab needs this in every cell; in Jupyter Notebook or JupyterLab you can comment this line out
shap.plots.force(explainer.expected_value[1],shap_values[1][0],train.drop(columns=['Survived']).iloc[0])
The plot above shows, for a single sample, how the individual features push the prediction from the base value to f(x) (the final output), i.e. each feature's contribution. Red features push the prediction up (to the right); blue features push it down (to the left). The width of each feature's block reflects the magnitude of its SHAP value. The base value is the mean of the model's predictions over the training samples (the references do not state this explicitly; it was verified empirically below). For regression it is the mean prediction; for classification it is the per-class mean. It is the value of explainer.expected_value.
# Inspect the base value
explainer.expected_value
[0.908824118858436, -0.908824118858436]
Because this is a classification problem, explainer.expected_value has two entries: the first is the mean for class 0 and the second the mean for class 1. For XGBoost and LightGBM, SHAP works in log-odds (raw margin) space. One more point worth checking:
y_train_prob = clf.predict(train.drop(columns=['Survived']))
print('shap base value:', explainer.expected_value[1], ' mean of log-odds of predictions:', np.log(y_train_prob / (1 - y_train_prob)).mean())
shap base value: -0.908824118858436  mean of log-odds of predictions: -0.9088241188584363
As the output shows, the SHAP base value and the mean of the log-odds of the training predictions are essentially equal. If you instead computed the mean log-odds on the test set, it would differ noticeably from the base value.
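A quick way to convince yourself of the additivity property on this model is to rebuild each raw prediction from the base value and the SHAP values (a sketch using the objects already defined above):

# Per-sample check: base value + sum of SHAP values should equal the
# model's raw (log-odds) output for that sample.
raw_pred = np.log(y_train_prob / (1 - y_train_prob))                  # per-sample log-odds
reconstructed = explainer.expected_value[1] + shap_values[1].sum(axis=1)
print(np.allclose(raw_pred, reconstructed, atol=1e-4))                # should print True (up to floating-point error)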
# Method 2: inspect the feature contributions of a single sample
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[1], shap_values[1][0],train.drop(columns=['Survived']).iloc[0])
This plot conveys the same information as shap.plots.force(), but I find it more intuitive to read.
waterfall_legacy is used here as a stand-in for shap.plots.waterfall(), which raises an error with the old-style explainer output used in this tutorial.
# Method 3: inspect the feature contributions of a single sample
shap.bar_plot(shap_values[1][0],train.drop(columns=['Survived']).iloc[0])
This plot is very similar to the previous two, but it does not mark the base value or f(x) (the final prediction).
shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values[1], train.drop(columns=['Survived']))  # feature contributions for all training samples
This plot shows the feature contributions of all training samples; hovering and dragging the mouse lets you inspect each individual sample.
X axis: the samples can be ordered by several criteria, including sample similarity (similar samples are placed next to each other), model output, original sample order, or the value of any single feature.
Y axis: you can plot the final prediction or the contribution of each individual feature.
Possible application scenarios:
shap.summary_plot(shap_values, train.drop(columns=['Survived']),plot_type="bar")
This is the feature-importance plot built from SHAP values. For each class (0 and 1), it takes the absolute SHAP value of every sample and averages them per feature.
It shows the importance for class 0 and class 1 separately; since we saw above that the class-0 and class-1 SHAP values have identical absolute values, looking at either one is in principle enough.
shap.summary_plot(shap_values[1], train.drop(columns=['Survived']),plot_type="bar")
This is the feature-importance view we use most often.
It is computed exactly as above, except that only the class-1 SHAP values are used.
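The bar heights can be reproduced by hand, which also makes the definition explicit (a sketch using the arrays computed above):

# Mean absolute SHAP value per feature, i.e. what summary_plot(plot_type="bar") draws.
feat_names = train.drop(columns=['Survived']).columns
mean_abs_shap = np.abs(shap_values[1]).mean(axis=0)
importance = pd.Series(mean_abs_shap, index=feat_names).sort_values(ascending=False)
print(importance)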
shap.summary_plot(shap_values[1], train.drop(columns=['Survived']))
This can be seen as a more detailed version of the feature-importance plot. Every SHAP value of every feature of every sample is plotted, which reveals the overall patterns and makes outlier predictions easy to spot. Each row is a feature and the x axis is the SHAP value; each dot is a sample, colored by the feature value (red = high, blue = low).
For example, take Sex: the SHAP values for Sex=1 are smaller than those for Sex=0. In the context of this problem, this means passengers with Sex=0 (i.e. female) were more likely to survive (Survived=1). Likewise, the lower the passenger class (the larger the Pclass value), the smaller the SHAP value and the harder it was to survive.
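To back up the reading of the Sex row numerically, you can group the SHAP values of that feature by its value (a sketch; the column order is assumed to match the training frame):

# Compare the SHAP value of the Sex feature between Sex=0 (female) and Sex=1 (male).
X_tr = train.drop(columns=['Survived'])
sex_shap = pd.Series(shap_values[1][:, X_tr.columns.get_loc('Sex')], index=X_tr.index)
print(sex_shap.groupby(X_tr['Sex']).mean())
# Expected pattern: a positive mean for Sex=0 and a negative mean for Sex=1,
# i.e. females are pushed toward Survived=1, consistent with the plot.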
Possible application scenarios:
shap.dependence_plot('Age', shap_values[1], train.drop(columns=['Survived']),interaction_index=None)
As shown, the x axis is the value of the Age feature and the y axis is the corresponding SHAP value; the overall trend first decreases and then rises, suggesting that children and the elderly were more likely to survive.
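The same trend can be checked numerically by binning Age and averaging the corresponding SHAP values (a sketch; the bin edges are arbitrary, and 0 is the fill value used for missing ages):

# Average SHAP value of Age within coarse age bins.
X_tr = train.drop(columns=['Survived'])
age_shap = pd.Series(shap_values[1][:, X_tr.columns.get_loc('Age')], index=X_tr.index)
age_bins = pd.cut(X_tr['Age'], bins=[-1, 0, 12, 30, 50, 80])
print(age_shap.groupby(age_bins).mean())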
Possible application scenarios:
shap.dependence_plot('Age', shap_values[1], train.drop(columns=['Survived']),interaction_index='Pclass')
As shown, the x axis is the value of Age, the y axis is the corresponding SHAP value, and the color encodes Pclass.
The plot suggests that among passengers in the lower classes, older passengers were more likely to survive.
SHAP offers more than just TreeExplainer; there are other explainers as well, such as KernelExplainer (model-agnostic, works with any prediction function), DeepExplainer (deep-learning models), and LinearExplainer (linear models). Below we train a simple neural network with Keras and explain it with KernelExplainer; the data preparation is repeated first:
import pandas as pd
import numpy as np

all_data = pd.read_csv('shap_tutorial/train_titantic.csv')

# Drop columns we do not need for now
all_data.drop(['PassengerId','Name'], axis=1, inplace=True)

# Fill missing values
all_data['Age'].fillna(0, inplace=True)
all_data['Cabin'].fillna('UNK', inplace=True)
all_data['Embarked'].fillna('UNK', inplace=True)

# Label-encode Sex, Ticket, Cabin and Embarked
from sklearn.preprocessing import LabelEncoder
cate_cols = ['Sex','Ticket','Cabin','Embarked']
cate_features = {}
cate_cnt = {}
for cc in cate_cols:
    enc = LabelEncoder()
    all_data[cc] = enc.fit_transform(all_data[cc])
    cate_cnt[cc] = all_data[cc].max()
cate_cnt['Parch'] = all_data['Parch'].max()
cate_cnt['Pclass'] = all_data['Pclass'].max()
cate_cnt['SibSp'] = all_data['SibSp'].max()
cate_cnt
{'Cabin': 147,
'Embarked': 3,
'Parch': 6,
'Pclass': 3,
'Sex': 1,
'SibSp': 8,
'Ticket': 680}
# Split into training and test sets
use_cols = ['Cabin','Embarked','Parch','Pclass','Sex','SibSp','Ticket']
from sklearn.model_selection import train_test_split
y = all_data['Survived']
X = all_data.copy()
X.drop(columns=['Survived'], inplace=True)
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2)
import tensorflow as tf
print(tf.__version__)

from tensorflow.keras.layers import *
from tensorflow.keras import Input
from tensorflow.keras import Model

def mlp(train, use_cols):
    input_list = []
    emb_list = []
    for c in train.columns:
        input_list.append(Input(shape=(1,), name=c))
        if c in use_cols:
            # categorical feature: embed it
            emb_feature = Flatten()(Embedding(input_dim=train[c].max()+1, output_dim=4)(input_list[-1]))
        else:
            emb_feature = input_list[-1]
        emb_list.append(emb_feature)
    # concatenate all inputs
    all_in = Concatenate()(emb_list)
    hidden = Dense(32, activation='relu')(all_in)
    hidden = Dense(16, activation='relu')(hidden)
    hidden = Dense(8, activation='relu')(hidden)
    y = Dense(1, activation='sigmoid')(hidden)
    # build the model
    model = Model(inputs=input_list, outputs=[y])
    return model
2.7.0
model = mlp(X, use_cols)
model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['acc'])
model.fit([X_train[c] for c in X_train.columns], y_train, epochs=18, batch_size=32, validation_data=([X_test[c] for c in X_test.columns], y_test))
Epoch 1/18 23/23 [==============================] - 4s 26ms/step - loss: 1.4307 - acc: 0.4789 - val_loss: 0.7328 - val_acc: 0.6536 Epoch 2/18 23/23 [==============================] - 0s 8ms/step - loss: 0.6792 - acc: 0.6362 - val_loss: 0.6359 - val_acc: 0.6704 Epoch 3/18 23/23 [==============================] - 0s 9ms/step - loss: 0.6378 - acc: 0.6657 - val_loss: 0.6211 - val_acc: 0.6760 Epoch 4/18 23/23 [==============================] - 0s 11ms/step - loss: 0.6203 - acc: 0.6924 - val_loss: 0.6124 - val_acc: 0.6927 Epoch 5/18 23/23 [==============================] - 0s 10ms/step - loss: 0.6074 - acc: 0.6882 - val_loss: 0.6106 - val_acc: 0.7039 Epoch 6/18 23/23 [==============================] - 0s 12ms/step - loss: 0.5984 - acc: 0.6924 - val_loss: 0.6007 - val_acc: 0.7151 Epoch 7/18 23/23 [==============================] - 0s 12ms/step - loss: 0.5701 - acc: 0.7233 - val_loss: 0.5816 - val_acc: 0.7207 Epoch 8/18 23/23 [==============================] - 0s 14ms/step - loss: 0.5453 - acc: 0.7374 - val_loss: 0.5581 - val_acc: 0.7542 Epoch 9/18 23/23 [==============================] - 0s 9ms/step - loss: 0.5228 - acc: 0.7556 - val_loss: 0.5418 - val_acc: 0.7039 Epoch 10/18 23/23 [==============================] - 0s 7ms/step - loss: 0.4760 - acc: 0.7921 - val_loss: 0.5150 - val_acc: 0.7598 Epoch 11/18 23/23 [==============================] - 0s 7ms/step - loss: 0.4496 - acc: 0.8076 - val_loss: 0.4834 - val_acc: 0.8156 Epoch 12/18 23/23 [==============================] - 0s 10ms/step - loss: 0.4187 - acc: 0.8371 - val_loss: 0.4999 - val_acc: 0.7821 Epoch 13/18 23/23 [==============================] - 0s 12ms/step - loss: 0.3951 - acc: 0.8581 - val_loss: 0.4879 - val_acc: 0.8045 Epoch 14/18 23/23 [==============================] - 0s 10ms/step - loss: 0.3624 - acc: 0.8736 - val_loss: 0.4686 - val_acc: 0.8101 Epoch 15/18 23/23 [==============================] - 0s 13ms/step - loss: 0.3317 - acc: 0.8862 - val_loss: 0.4619 - val_acc: 0.8436 Epoch 16/18 23/23 [==============================] - 0s 11ms/step - loss: 0.2931 - acc: 0.9059 - val_loss: 0.4536 - val_acc: 0.8212 Epoch 17/18 23/23 [==============================] - 0s 9ms/step - loss: 0.2752 - acc: 0.9157 - val_loss: 0.4332 - val_acc: 0.8492 Epoch 18/18 23/23 [==============================] - 0s 11ms/step - loss: 0.2314 - acc: 0.9382 - val_loss: 0.4477 - val_acc: 0.8492 <keras.callbacks.History at 0x7f47edcbd9d0>
from sklearn import metrics
ytrain_pred = np.where(model.predict([X_train[c] for c in X_train.columns])>=0.5, 1,0)
ytest_pred = np.where(model.predict([X_test[c] for c in X_test.columns])>=0.5, 1,0)
print("train classification report")
print(metrics.classification_report(y_train, ytrain_pred))
print('*'*60)
print("test classification report")
print(metrics.classification_report(y_test, ytest_pred))
train classification report precision recall f1-score support 0 0.91 0.99 0.95 437 1 0.97 0.84 0.90 275 accuracy 0.93 712 macro avg 0.94 0.91 0.92 712 weighted avg 0.93 0.93 0.93 712 ************************************************************ test classification report precision recall f1-score support 0 0.85 0.93 0.89 112 1 0.86 0.72 0.78 67 accuracy 0.85 179 macro avg 0.85 0.82 0.83 179 weighted avg 0.85 0.85 0.85 179
!pip install shap
import warnings
warnings.filterwarnings("ignore")
import shap
shap.initjs()
Requirement already satisfied: shap in /usr/local/lib/python3.7/dist-packages (0.40.0) Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from shap) (1.4.1) Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap) (21.3) Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from shap) (1.1.5) Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap) (0.51.2) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from shap) (1.0.1) Requirement already satisfied: slicer==0.0.7 in /usr/local/lib/python3.7/dist-packages (from shap) (0.0.7) Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from shap) (1.19.5) Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap) (1.3.0) Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap) (4.62.3) Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>20.9->shap) (3.0.6) Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap) (0.34.0) Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba->shap) (57.4.0) Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2018.9) Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2.8.2) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->shap) (1.15.0) Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (1.1.0) Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (3.0.0)
# Wrap the Keras model so that it takes a single 2-D array (as KernelExplainer
# expects) and returns one probability per row.
def f(X):
    return model.predict([X[:, i] for i in range(X.shape[1])]).flatten()
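As a quick sanity check (a sketch), the wrapper should take a plain 2-D array and return one probability per row, which is the single-output interface KernelExplainer works with:

# The wrapper maps a (n_samples, n_features) array to n_samples probabilities.
probs = f(X_train.values[:3])
print(probs.shape)   # expected: (3,)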
# This explainer is very slow: roughly 3 minutes here, and longer for larger datasets. It is much slower than TreeExplainer.
explainer = shap.KernelExplainer(f, X_train.iloc[:10])  # use 10 samples as the background set; using all the data would be very slow
shap_values = explainer.shap_values(X_train)  # compute the SHAP values
np.array(shap_values).shape  # inspect the shape of the SHAP values
100%| | 0/712 [00:00<?, ?it/s]
(712, 9)
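KernelExplainer's runtime grows with the size of the background dataset, so instead of a raw slice of 10 rows you can build a more representative background set with the helpers the shap package provides (a sketch; the sample sizes are arbitrary):

# Summarize the training data into a small background set for KernelExplainer.
background = shap.sample(X_train, 50)          # random sample of 50 rows
# background = shap.kmeans(X_train, 10)        # alternative: weighted k-means summary
explainer_bg = shap.KernelExplainer(f, background)
shap_values_bg = explainer_bg.shap_values(X_test.iloc[:5])   # explain a few rows only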
print(X_train.iloc[0,:])
print('Survived ', y_train.iloc[0])
Pclass 3.0000
Sex 1.0000
Age 29.0000
SibSp 1.0000
Parch 0.0000
Ticket 315.0000
Fare 7.0458
Cabin 147.0000
Embarked 2.0000
Name: 477, dtype: float64
Survived 0
Note that the SHAP values are 2-dimensional here: the wrapped function f returns a single probability (for the positive class), so there is one SHAP value per feature rather than one array per class.
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0], X_train.iloc[0,:])
y_train_prob = model.predict([X_train[c] for c in X_train.columns])
print('shap base value:', explainer.expected_value, ' mean of prediction: ', np.mean(y_train_prob[:10]))  # the base value is the mean prediction over however many background samples the explainer was given above
shap base value: 0.1138872653245926 mean of prediction: 0.113887265
shap.summary_plot(shap_values, X_train,plot_type="bar")
shap.dependence_plot('Age', shap_values, X_train,interaction_index=None)
shap.dependence_plot('Age', shap_values, X_train,interaction_index='Pclass')