Introduction and Application of SHAP (with Code)


SHAP Tutorial

This article mainly covers:

  • The principles of SHAP
  • How to apply SHAP

Introduction to SHAP

SHAP aims to explain a model's output by computing the contribution of every feature to the prediction for each individual sample. Inspired by cooperative game theory, SHAP builds an additive explanation model in which every feature is treated as a "contributor". For each sample the model produces a prediction, and the SHAP values are the numbers allocated to the individual features of that sample.

Let the $i$-th sample be $x_i$, let the $j$-th feature of the $i$-th sample be $x_i^j$, let the model's prediction for this sample be $y_i$, and let the baseline of the whole model (usually the mean of the target over all samples) be $y_{base}$. The SHAP values then satisfy:

$y_i = y_{base} + f(x_i^1) + f(x_i^2) + \dots + f(x_i^j)$

where $f(x_i^j)$ is the SHAP value of $x_i^j$. Intuitively, $f(x_i^1)$ is the contribution of the first feature of the $i$-th sample to the final prediction $y_i$: when $f(x_i^1) > 0$ the feature raises the prediction (a positive effect); otherwise it lowers the prediction (a negative effect). The values $f(x_i^j)$ are obtained by a fitting procedure; how exactly they are fitted is beyond the scope of this article.
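As a concrete (purely hypothetical) illustration: if the baseline $y_{base}$ is 0.30 and the model predicts $y_i = 0.80$ for some sample, then that sample's SHAP values must sum to 0.50, for instance $f(\text{Sex}) = 0.35$, $f(\text{Age}) = 0.20$ and $f(\text{Fare}) = -0.05$; the first two features push the prediction up while the third pulls it down.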

Using TreeSHAP on a tree model as an example

This tutorial uses the Titanic dataset as its running example.
Here is a quick look at the Titanic features:

PassengerId: passenger ID

Survived: survival (0 = died, 1 = survived)

Pclass: passenger class (1st/2nd/3rd)

Name: passenger name

Sex: sex

Age: age

SibSp: number of siblings/spouses aboard

Parch: number of parents/children aboard

Ticket: ticket number

Fare: fare

Cabin: cabin

Embarked: port of embarkation

This is a binary classification problem: Survived is the target, and Pclass, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin and Embarked are the features.

We will use a gradient-boosted tree ensemble (LightGBM) as the example model.

Note: we keep the problem as simple as possible so that it is easy to see how SHAP is applied.

# pull the data
!git clone https://github.com/pangpang97/shap_tutorial
Cloning into 'shap_tutorial'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (10/10), done.
remote: Total 12 (delta 1), reused 0 (delta 0), pack-reused 0
Unpacking objects: 100% (12/12), done.
# import the required packages
import pandas as pd
import numpy as np
all_data = pd.read_csv('shap_tutorial/train_titantic.csv')
all_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB

Simple data preprocessing

# drop columns we don't need for now
all_data.drop(['PassengerId','Name'], axis=1, inplace=True)
# fill missing values
all_data['Age'].fillna(0, inplace=True)
all_data['Cabin'].fillna('UNK', inplace=True)
all_data['Embarked'].fillna('UNK', inplace=True)
all_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 10 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  891 non-null    int64  
 1   Pclass    891 non-null    int64  
 2   Sex       891 non-null    object 
 3   Age       891 non-null    float64
 4   SibSp     891 non-null    int64  
 5   Parch     891 non-null    int64  
 6   Ticket    891 non-null    object 
 7   Fare      891 non-null    float64
 8   Cabin     891 non-null    object 
 9   Embarked  891 non-null    object 
dtypes: float64(2), int64(4), object(4)
memory usage: 69.7+ KB
# label-encode Sex, Ticket, Cabin and Embarked
from sklearn.preprocessing import LabelEncoder
cate_cols = ['Sex','Ticket','Cabin','Embarked']
for cc in cate_cols:
    enc = LabelEncoder()
    all_data[cc] = enc.fit_transform(all_data[cc])

Training LightGBM

# split into train and test sets
from sklearn.model_selection import train_test_split
train, test = train_test_split(all_data,test_size=0.2)

# train a LightGBM model
import lightgbm as lgb
from sklearn import metrics
params = {'objective': 'binary',
          'metric': 'binary_logloss',
          'num_round': 80,
          'verbose':1
              }
num_round = params.pop('num_round',1000)
xtrain = lgb.Dataset(train.drop(columns=['Survived']), train['Survived'],free_raw_data=False)
xeval = lgb.Dataset(test.drop(columns=['Survived']), test['Survived'],free_raw_data=False)
evallist = [xtrain, xeval]
clf = lgb.train(params, xtrain, num_round, valid_sets=evallist)
ytrain = np.where(clf.predict(train.drop(columns=['Survived']))>=0.5, 1,0)
ytest = np.where(clf.predict(test.drop(columns=['Survived']))>=0.5, 1,0)
print("train classification report")
print(metrics.classification_report(train['Survived'], ytrain))
print('*'*60)
print("test classification report")
print(metrics.classification_report(test['Survived'], ytest))
[1]	training's binary_logloss: 0.616752	valid_1's binary_logloss: 0.628188
[2]	training's binary_logloss: 0.577503	valid_1's binary_logloss: 0.59757
[3]	training's binary_logloss: 0.545282	valid_1's binary_logloss: 0.573299
[4]	training's binary_logloss: 0.517894	valid_1's binary_logloss: 0.55145
[5]	training's binary_logloss: 0.49365	valid_1's binary_logloss: 0.53399
[6]	training's binary_logloss: 0.472167	valid_1's binary_logloss: 0.518513
[7]	training's binary_logloss: 0.453641	valid_1's binary_logloss: 0.506068
[8]	training's binary_logloss: 0.437905	valid_1's binary_logloss: 0.494044
[9]	training's binary_logloss: 0.424174	valid_1's binary_logloss: 0.485643
[10]	training's binary_logloss: 0.410632	valid_1's binary_logloss: 0.477679
[11]	training's binary_logloss: 0.398761	valid_1's binary_logloss: 0.470235
[12]	training's binary_logloss: 0.388724	valid_1's binary_logloss: 0.464987
[13]	training's binary_logloss: 0.379662	valid_1's binary_logloss: 0.460591
[14]	training's binary_logloss: 0.370902	valid_1's binary_logloss: 0.457459
[15]	training's binary_logloss: 0.363174	valid_1's binary_logloss: 0.455261
[16]	training's binary_logloss: 0.356024	valid_1's binary_logloss: 0.451337
[17]	training's binary_logloss: 0.348805	valid_1's binary_logloss: 0.447478
[18]	training's binary_logloss: 0.340424	valid_1's binary_logloss: 0.444287
[19]	training's binary_logloss: 0.3332	valid_1's binary_logloss: 0.439115
[20]	training's binary_logloss: 0.326891	valid_1's binary_logloss: 0.436567
[21]	training's binary_logloss: 0.321057	valid_1's binary_logloss: 0.434214
[22]	training's binary_logloss: 0.315413	valid_1's binary_logloss: 0.433245
[23]	training's binary_logloss: 0.310579	valid_1's binary_logloss: 0.434182
[24]	training's binary_logloss: 0.305809	valid_1's binary_logloss: 0.43319
[25]	training's binary_logloss: 0.301533	valid_1's binary_logloss: 0.432078
[26]	training's binary_logloss: 0.295784	valid_1's binary_logloss: 0.430593
[27]	training's binary_logloss: 0.290461	valid_1's binary_logloss: 0.428543
[28]	training's binary_logloss: 0.285622	valid_1's binary_logloss: 0.426985
[29]	training's binary_logloss: 0.279571	valid_1's binary_logloss: 0.426044
[30]	training's binary_logloss: 0.274883	valid_1's binary_logloss: 0.427347
[31]	training's binary_logloss: 0.270582	valid_1's binary_logloss: 0.427791
[32]	training's binary_logloss: 0.265887	valid_1's binary_logloss: 0.428961
[33]	training's binary_logloss: 0.260971	valid_1's binary_logloss: 0.430992
[34]	training's binary_logloss: 0.254612	valid_1's binary_logloss: 0.429185
[35]	training's binary_logloss: 0.250871	valid_1's binary_logloss: 0.429506
[36]	training's binary_logloss: 0.245969	valid_1's binary_logloss: 0.425777
[37]	training's binary_logloss: 0.24145	valid_1's binary_logloss: 0.425065
[38]	training's binary_logloss: 0.237223	valid_1's binary_logloss: 0.423375
[39]	training's binary_logloss: 0.233457	valid_1's binary_logloss: 0.42279
[40]	training's binary_logloss: 0.229837	valid_1's binary_logloss: 0.421586
[41]	training's binary_logloss: 0.2258	valid_1's binary_logloss: 0.419546
[42]	training's binary_logloss: 0.222053	valid_1's binary_logloss: 0.420543
[43]	training's binary_logloss: 0.218142	valid_1's binary_logloss: 0.419648
[44]	training's binary_logloss: 0.214406	valid_1's binary_logloss: 0.417376
[45]	training's binary_logloss: 0.21062	valid_1's binary_logloss: 0.417219
[46]	training's binary_logloss: 0.207693	valid_1's binary_logloss: 0.418382
[47]	training's binary_logloss: 0.204459	valid_1's binary_logloss: 0.420574
[48]	training's binary_logloss: 0.201661	valid_1's binary_logloss: 0.420458
[49]	training's binary_logloss: 0.198652	valid_1's binary_logloss: 0.420256
[50]	training's binary_logloss: 0.195849	valid_1's binary_logloss: 0.41788
[51]	training's binary_logloss: 0.192828	valid_1's binary_logloss: 0.419856
[52]	training's binary_logloss: 0.189455	valid_1's binary_logloss: 0.419239
[53]	training's binary_logloss: 0.186862	valid_1's binary_logloss: 0.418061
[54]	training's binary_logloss: 0.184144	valid_1's binary_logloss: 0.420203
[55]	training's binary_logloss: 0.18186	valid_1's binary_logloss: 0.419781
[56]	training's binary_logloss: 0.179336	valid_1's binary_logloss: 0.418251
[57]	training's binary_logloss: 0.176953	valid_1's binary_logloss: 0.418373
[58]	training's binary_logloss: 0.174188	valid_1's binary_logloss: 0.421177
[59]	training's binary_logloss: 0.171624	valid_1's binary_logloss: 0.422029
[60]	training's binary_logloss: 0.169554	valid_1's binary_logloss: 0.421043
[61]	training's binary_logloss: 0.167043	valid_1's binary_logloss: 0.420784
[62]	training's binary_logloss: 0.164732	valid_1's binary_logloss: 0.421378
[63]	training's binary_logloss: 0.162674	valid_1's binary_logloss: 0.421023
[64]	training's binary_logloss: 0.161026	valid_1's binary_logloss: 0.422021
[65]	training's binary_logloss: 0.159152	valid_1's binary_logloss: 0.423376
[66]	training's binary_logloss: 0.157098	valid_1's binary_logloss: 0.423112
[67]	training's binary_logloss: 0.154923	valid_1's binary_logloss: 0.424078
[68]	training's binary_logloss: 0.152742	valid_1's binary_logloss: 0.423512
[69]	training's binary_logloss: 0.150701	valid_1's binary_logloss: 0.422217
[70]	training's binary_logloss: 0.149009	valid_1's binary_logloss: 0.420999
[71]	training's binary_logloss: 0.147453	valid_1's binary_logloss: 0.421606
[72]	training's binary_logloss: 0.145442	valid_1's binary_logloss: 0.4223
[73]	training's binary_logloss: 0.143639	valid_1's binary_logloss: 0.421754
[74]	training's binary_logloss: 0.14174	valid_1's binary_logloss: 0.421546
[75]	training's binary_logloss: 0.13988	valid_1's binary_logloss: 0.423675
[76]	training's binary_logloss: 0.138371	valid_1's binary_logloss: 0.423968
[77]	training's binary_logloss: 0.136794	valid_1's binary_logloss: 0.424537
[78]	training's binary_logloss: 0.134842	valid_1's binary_logloss: 0.425823
[79]	training's binary_logloss: 0.132775	valid_1's binary_logloss: 0.424314
[80]	training's binary_logloss: 0.131154	valid_1's binary_logloss: 0.42698
train classification report
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       439
           1       0.99      0.95      0.97       273

    accuracy                           0.98       712
   macro avg       0.98      0.97      0.97       712
weighted avg       0.98      0.98      0.98       712

************************************************************
test classification report
              precision    recall  f1-score   support

           0       0.83      0.93      0.88       110
           1       0.86      0.70      0.77        69

    accuracy                           0.84       179
   macro avg       0.84      0.81      0.82       179
weighted avg       0.84      0.84      0.83       179

Common uses of SHAP values

!pip install shap
import warnings
warnings.filterwarnings("ignore")
import shap
shap.initjs()
Collecting shap
  Downloading shap-0.40.0-cp37-cp37m-manylinux2010_x86_64.whl (564 kB)
     |████████████████████████████████| 564 kB 3.6 MB/s 
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from shap) (1.4.1)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from shap) (1.0.1)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap) (1.3.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from shap) (1.19.5)
Collecting slicer==0.0.7
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap) (21.3)
Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap) (4.62.3)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from shap) (1.1.5)
Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap) (0.51.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>20.9->shap) (3.0.6)
Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap) (0.34.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba->shap) (57.4.0)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2018.9)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2.8.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->shap) (1.15.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (3.0.0)
Installing collected packages: slicer, shap
Successfully installed shap-0.40.0 slicer-0.0.7


# this can be very slow if the dataset is large
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(train.drop(columns=['Survived'])) # compute the SHAP values
np.array(shap_values).shape # check the dimensions of the SHAP values
(2, 712, 9)

The SHAP values form a 3-dimensional array, so each sample has two sets of SHAP values.

The first dimension selects the class: index 0 gives the SHAP values for class 0 (negative) and index 1 gives the SHAP values for class 1 (positive).

The remaining two dimensions index the samples and the features.

(shap_values[0] == -1* shap_values[1]).all()
True

As you can see, the class-0 SHAP values are exactly the negatives of the class-1 SHAP values.

SHAP values for a single sample

Possible use cases:

  • Bad-case analysis: compare what differs between misclassified and correctly classified samples
  • Business-oriented analysis: e.g. in a churn scenario, explain the concrete reasons why a user is predicted to churn

As an example, let's look at the class-1 SHAP values of the first training sample.

train.drop(columns=['Survived']).iloc[0].T
Pclass        3.0000
Sex           1.0000
Age          20.0000
SibSp         1.0000
Parch         1.0000
Ticket      187.0000
Fare         15.7417
Cabin       147.0000
Embarked      0.0000
Name: 622, dtype: float64
# first way to inspect a single sample's feature contributions
shap.initjs() # Colab needs this in every cell; in Jupyter Notebook/Lab you can comment this line out
shap.plots.force(explainer.expected_value[1],shap_values[1][0],train.drop(columns=['Survived']).iloc[0])

[Figure: force plot of the first sample's class-1 SHAP values]

This plot shows, for a single sample, how each feature pushes the prediction from the base value to f(x) (the final output), i.e. each feature's contribution. Red features push the prediction up (to the right); blue features push it down (to the left). The width of each feature block reflects the magnitude of its SHAP value. The base value is the mean of the model's predictions over all training samples (the reference material never states this explicitly; I verified it empirically): for regression it is the mean prediction, and for classification it is the per-class mean. It is the value stored in explainer.expected_value.

# look at the base value
explainer.expected_value
[0.908824118858436, -0.908824118858436]

Because this is a classification problem, explainer.expected_value has two entries: the first is the mean for class 0 and the second is the mean for class 1. For XGBoost and LightGBM, SHAP works in log-odds space. There is one more point worth checking:

y_train_prob = clf.predict(train.drop(columns=['Survived']))
print('shap base value:', explainer.expected_value[1], ' mean of log-odds predictions:', np.log(y_train_prob/ (1 - y_train_prob)).mean())
shap base value: -0.908824118858436  mean of log-odds predictions: -0.9088241188584363

As the output shows, the SHAP base value and the mean of the log-odds predictions are essentially equal. If you compute the SHAP values on the test set instead, the mean log-odds prediction will differ noticeably from the base value.
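To make the additivity property concrete, here is a minimal sketch (assuming clf, explainer and shap_values from above) that checks, for a single training sample, that the base value plus the sum of its class-1 SHAP values reproduces the model's raw log-odds output:

# verify local accuracy for one sample: base value + sum(SHAP values) ≈ log-odds prediction
i = 0
X_tr = train.drop(columns=['Survived'])
p = clf.predict(X_tr.iloc[[i]])[0]                      # predicted probability for sample i
log_odds = np.log(p / (1 - p))                          # convert to log-odds, the space SHAP uses here
reconstructed = explainer.expected_value[1] + shap_values[1][i].sum()
print(log_odds, reconstructed)                          # the two numbers should be (almost) identical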

# second way to inspect a single sample's feature contributions
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[1], shap_values[1][0],train.drop(columns=['Survived']).iloc[0])

[Figure: waterfall plot of the first sample's class-1 SHAP values]

This plot conveys the same information as shap.plots.force(), but I find it more intuitive to read.

This function is used in place of shap.plots.waterfall(), which raises an error in this setting.

# third way to inspect a single sample's feature contributions
shap.bar_plot(shap_values[1][0],train.drop(columns=['Survived']).iloc[0])

[Figure: bar plot of the first sample's class-1 SHAP values]
This plot is very similar to the previous two, but it does not mark the base value or f(x) (the final prediction).

SHAP values for multiple samples

shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values[1],train.drop(columns=['Survived'])) # feature contributions for every training sample

[Figure: force plot of all training samples]

This plot shows the feature contributions for all training samples; hover the mouse over it to inspect each individual sample.

X axis: the samples can be ordered by several criteria, including sample similarity (similar samples are placed next to each other), model output, original sample order, or the value of an individual feature;

Y axis: the plot can show the final prediction or the contribution of each individual feature.

Overall model view (feature importance)

Possible use cases:

  • Analyze the model's features to guide feature selection and feature engineering
  • When the model behaves abnormally, check whether a feature leaks the target (data leakage)
  • Business analysis: e.g. in a churn scenario, identify which factors are the main drivers of churn
shap.summary_plot(shap_values, train.drop(columns=['Survived']),plot_type="bar")

[Figure: summary bar plot of mean |SHAP| per feature, for both classes]

This is the feature-importance plot built from SHAP values: for each feature, the class-0 and class-1 SHAP values are taken in absolute value and averaged over all samples.

The plot shows the importance for class 0 and class 1 separately. As established above, the class-0 and class-1 SHAP values have identical absolute values, so in principle looking at either one is enough.

shap.summary_plot(shap_values[1], train.drop(columns=['Survived']),plot_type="bar")

[Figure: summary bar plot of mean |SHAP| per feature, class 1 only]

This is the feature-importance view we use most often.

It is computed in the same way as above, except that only the class-1 SHAP values are used.
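For reference, here is a minimal sketch (assuming shap_values and train from above) of how these bar heights can be reproduced by hand: they are simply the mean absolute class-1 SHAP value per feature.

# mean(|SHAP|) per feature, which is what summary_plot(..., plot_type="bar") displays
X_tr = train.drop(columns=['Survived'])
importance = pd.DataFrame({
    'feature': X_tr.columns,
    'mean_abs_shap_class1': np.abs(shap_values[1]).mean(axis=0),
}).sort_values('mean_abs_shap_class1', ascending=False)
print(importance)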

shap.summary_plot(shap_values[1], train.drop(columns=['Survived']))

[Figure: summary (beeswarm) plot of class-1 SHAP values]

This can be seen as a fine-grained version of the feature-importance plot: the SHAP value of every feature is plotted for every sample, which gives a better picture of the overall patterns and makes prediction outliers visible. Each row is a feature and the x axis is the SHAP value. Each point is a sample, colored by the feature value (red = high, blue = low).

Take Sex as an example: the SHAP values for Sex=1 are smaller than those for Sex=0, so in this problem Sex=0 (i.e. female) passengers were more likely to survive (Survived=1). Similarly, the lower the passenger class (the larger the Pclass value), the smaller the SHAP value and the harder it was to survive.
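To back up this reading of the plot with numbers, here is a minimal sketch (assuming shap_values and train from above) that compares the mean class-1 SHAP value of Sex between the two groups:

# mean class-1 SHAP value of the Sex feature, grouped by the value of Sex
X_tr = train.drop(columns=['Survived'])
sex_shap = pd.Series(shap_values[1][:, X_tr.columns.get_loc('Sex')], index=X_tr.index)
print(sex_shap.groupby(X_tr['Sex']).mean())   # Sex=0 (female) should show the larger mean SHAP value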

Relationship between a single feature and the prediction

Possible use cases:

  • Analyze whether a continuous feature has a linear or non-linear relationship with the prediction, which helps when discretizing the feature
shap.dependence_plot('Age', shap_values[1], train.drop(columns=['Survived']),interaction_index=None)

[Figure: dependence plot of Age vs. its SHAP value]

In this plot the x axis is the feature value (Age) and the y axis is the corresponding SHAP value. The overall trend falls and then rises, which suggests that children and elderly passengers were more likely to survive.
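Since one of the stated use cases is discretizing continuous features, here is a minimal sketch (assuming shap_values and train from above; the bin edges are arbitrary) that bins Age and looks at the mean SHAP value per bin:

# mean class-1 SHAP value of Age per age bin; the (-1, 0] bin catches the 0 used to fill missing ages
X_tr = train.drop(columns=['Survived'])
age_shap = pd.Series(shap_values[1][:, X_tr.columns.get_loc('Age')], index=X_tr.index)
age_bins = pd.cut(X_tr['Age'], bins=[-1, 0, 12, 30, 50, 80])
print(age_shap.groupby(age_bins).mean())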

Relationship between multiple features and the prediction

Possible use cases:

  • Analyze how a feature combination relates to the prediction and judge whether two features could be combined into a better one
shap.dependence_plot('Age', shap_values[1], train.drop(columns=['Survived']),interaction_index='Pclass')

[Figure: dependence plot of Age vs. its SHAP value, colored by Pclass]

Here the x axis is the feature value (Age), the y axis is the corresponding SHAP value, and the color encodes Pclass.

The plot suggests that among lower-class passengers, older passengers were more likely to survive.

Other Explainers

TreeExplainer is not the only Explainer SHAP provides; others include:

  • TreeExplainer: for XGBoost, LightGBM, CatBoost and the tree models in scikit-learn
  • DeepExplainer (Deep SHAP): for TensorFlow and Keras models
  • GradientExplainer: for TensorFlow and Keras models
  • KernelExplainer: for any model

Using SHAP with a deep learning model

import pandas as pd
import numpy as np
all_data = pd.read_csv('shap_tutorial/train_titantic.csv')

# drop columns we don't need for now
all_data.drop(['PassengerId','Name'], axis=1, inplace=True)
# fill missing values
all_data['Age'].fillna(0, inplace=True)
all_data['Cabin'].fillna('UNK', inplace=True)
all_data['Embarked'].fillna('UNK', inplace=True)

# label-encode Sex, Ticket, Cabin and Embarked
from sklearn.preprocessing import LabelEncoder
cate_cols = ['Sex','Ticket','Cabin','Embarked']
cate_features = {}
cate_cnt ={}
for cc in cate_cols:
    enc = LabelEncoder()
    all_data[cc] = enc.fit_transform(all_data[cc])
    cate_cnt[cc] = all_data[cc].max()

cate_cnt['Parch'] = all_data['Parch'].max()
cate_cnt['Pclass'] = all_data['Pclass'].max()
cate_cnt['SibSp'] = all_data['SibSp'].max()
cate_cnt
{'Cabin': 147,
 'Embarked': 3,
 'Parch': 6,
 'Pclass': 3,
 'Sex': 1,
 'SibSp': 8,
 'Ticket': 680}
# split into train and test sets
use_cols = ['Cabin','Embarked','Parch','Pclass','Sex','SibSp','Ticket']
from sklearn.model_selection import train_test_split
y = all_data['Survived']
X = all_data.copy()
X.drop(columns=['Survived'], inplace=True)
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2)
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.layers import *
from tensorflow.keras import Input
from tensorflow.keras import Model

def mlp(train, use_cols):
    input_list = []
    emb_list = []
    for c in train.columns:
        input_list.append(Input(shape=(1,),name=c))
        if c in use_cols:
            # categorical feature, so embed it
            emb_feature = Flatten()(Embedding(input_dim=train[c].max()+1, output_dim=4)(input_list[-1]))
        else:
            emb_feature = input_list[-1]
        emb_list.append(emb_feature)
    # concatenate all inputs
    all_in = Concatenate()(emb_list)
    hidden = Dense(32,activation='relu')(all_in)
    hidden = Dense(16, activation='relu')(hidden)
    hidden = Dense(8, activation='relu')(hidden)
    y = Dense(1, activation='sigmoid')(hidden)
    #model
    model = Model(inputs=input_list, outputs=[y])
    return model
2.7.0
model = mlp(X, use_cols)
model.compile(loss='binary_crossentropy', optimizer='adam',metrics=['acc'])
model.fit([X_train[c] for c in X_train.columns], y_train, epochs=18, batch_size=32, validation_data=([X_test[c] for c in X_test.columns], y_test))
Epoch 1/18
23/23 [==============================] - 4s 26ms/step - loss: 1.4307 - acc: 0.4789 - val_loss: 0.7328 - val_acc: 0.6536
Epoch 2/18
23/23 [==============================] - 0s 8ms/step - loss: 0.6792 - acc: 0.6362 - val_loss: 0.6359 - val_acc: 0.6704
Epoch 3/18
23/23 [==============================] - 0s 9ms/step - loss: 0.6378 - acc: 0.6657 - val_loss: 0.6211 - val_acc: 0.6760
Epoch 4/18
23/23 [==============================] - 0s 11ms/step - loss: 0.6203 - acc: 0.6924 - val_loss: 0.6124 - val_acc: 0.6927
Epoch 5/18
23/23 [==============================] - 0s 10ms/step - loss: 0.6074 - acc: 0.6882 - val_loss: 0.6106 - val_acc: 0.7039
Epoch 6/18
23/23 [==============================] - 0s 12ms/step - loss: 0.5984 - acc: 0.6924 - val_loss: 0.6007 - val_acc: 0.7151
Epoch 7/18
23/23 [==============================] - 0s 12ms/step - loss: 0.5701 - acc: 0.7233 - val_loss: 0.5816 - val_acc: 0.7207
Epoch 8/18
23/23 [==============================] - 0s 14ms/step - loss: 0.5453 - acc: 0.7374 - val_loss: 0.5581 - val_acc: 0.7542
Epoch 9/18
23/23 [==============================] - 0s 9ms/step - loss: 0.5228 - acc: 0.7556 - val_loss: 0.5418 - val_acc: 0.7039
Epoch 10/18
23/23 [==============================] - 0s 7ms/step - loss: 0.4760 - acc: 0.7921 - val_loss: 0.5150 - val_acc: 0.7598
Epoch 11/18
23/23 [==============================] - 0s 7ms/step - loss: 0.4496 - acc: 0.8076 - val_loss: 0.4834 - val_acc: 0.8156
Epoch 12/18
23/23 [==============================] - 0s 10ms/step - loss: 0.4187 - acc: 0.8371 - val_loss: 0.4999 - val_acc: 0.7821
Epoch 13/18
23/23 [==============================] - 0s 12ms/step - loss: 0.3951 - acc: 0.8581 - val_loss: 0.4879 - val_acc: 0.8045
Epoch 14/18
23/23 [==============================] - 0s 10ms/step - loss: 0.3624 - acc: 0.8736 - val_loss: 0.4686 - val_acc: 0.8101
Epoch 15/18
23/23 [==============================] - 0s 13ms/step - loss: 0.3317 - acc: 0.8862 - val_loss: 0.4619 - val_acc: 0.8436
Epoch 16/18
23/23 [==============================] - 0s 11ms/step - loss: 0.2931 - acc: 0.9059 - val_loss: 0.4536 - val_acc: 0.8212
Epoch 17/18
23/23 [==============================] - 0s 9ms/step - loss: 0.2752 - acc: 0.9157 - val_loss: 0.4332 - val_acc: 0.8492
Epoch 18/18
23/23 [==============================] - 0s 11ms/step - loss: 0.2314 - acc: 0.9382 - val_loss: 0.4477 - val_acc: 0.8492

<keras.callbacks.History at 0x7f47edcbd9d0>
from sklearn import metrics
ytrain_pred = np.where(model.predict([X_train[c] for c in X_train.columns])>=0.5, 1,0)
ytest_pred = np.where(model.predict([X_test[c] for c in X_test.columns])>=0.5, 1,0)
print("train classification report")
print(metrics.classification_report(y_train, ytrain_pred))
print('*'*60)
print("test classification report")
print(metrics.classification_report(y_test, ytest_pred))
train classification report
              precision    recall  f1-score   support

           0       0.91      0.99      0.95       437
           1       0.97      0.84      0.90       275

    accuracy                           0.93       712
   macro avg       0.94      0.91      0.92       712
weighted avg       0.93      0.93      0.93       712

************************************************************
test classification report
              precision    recall  f1-score   support

           0       0.85      0.93      0.89       112
           1       0.86      0.72      0.78        67

    accuracy                           0.85       179
   macro avg       0.85      0.82      0.83       179
weighted avg       0.85      0.85      0.85       179
!pip install shap
import warnings
warnings.filterwarnings("ignore")
import shap
shap.initjs()
Requirement already satisfied: shap in /usr/local/lib/python3.7/dist-packages (0.40.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from shap) (1.4.1)
Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.7/dist-packages (from shap) (21.3)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from shap) (1.1.5)
Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (from shap) (0.51.2)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from shap) (1.0.1)
Requirement already satisfied: slicer==0.0.7 in /usr/local/lib/python3.7/dist-packages (from shap) (0.0.7)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from shap) (1.19.5)
Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from shap) (1.3.0)
Requirement already satisfied: tqdm>4.25.0 in /usr/local/lib/python3.7/dist-packages (from shap) (4.62.3)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>20.9->shap) (3.0.6)
Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba->shap) (0.34.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba->shap) (57.4.0)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2018.9)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->shap) (2.8.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->shap) (1.15.0)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->shap) (3.0.0)


def f(X):
    return model.predict([X[:,i] for i in range(X.shape[1])]).flatten()
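A quick sanity check of the wrapper (a hypothetical usage example, assuming X_train from above): f takes a plain 2-D array of shape (n_samples, n_features) and returns a flat vector of predicted probabilities, which is the interface KernelExplainer expects.

print(f(X_train.values[:2]))        # two probabilities between 0 and 1
print(f(X_train.values[:2]).shape)  # (2,)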
# this Explainer is very slow: it takes about 3 minutes here, longer on larger datasets, and is much slower than TreeExplainer
explainer = shap.KernelExplainer(f, X_train.iloc[:10]) # use 10 samples as the background data; using all of the data would be very slow
shap_values = explainer.shap_values(X_train) # compute the SHAP values
np.array(shap_values).shape # check the dimensions of the SHAP values
  100%|          | 0/712 [00:00<?, ?it/s]
(712, 9)
print(X_train.iloc[0,:])
print('Survived ', y_train.iloc[0])
Pclass        3.0000
Sex           1.0000
Age          29.0000
SibSp         1.0000
Parch         0.0000
Ticket      315.0000
Fare          7.0458
Cabin       147.0000
Embarked      2.0000
Name: 477, dtype: float64
Survived  0

Note that here the SHAP values are 2-dimensional (there is no class dimension).

shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0], X_train.iloc[0,:])

[Figure: force plot of the first sample (KernelExplainer)]

y_train_prob = model.predict([X_train[c] for c in X_train.columns])
print('shap base value:',explainer.expected_value, ' mean of prediction: ',np.mean(y_train_prob[:10])) # the base value is the mean prediction over however many background samples were given to the Explainer
shap base value: 0.1138872653245926  mean of prediction:  0.113887265
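As with the tree model, the additive property can be checked per sample; here is a minimal sketch (assuming f, explainer and shap_values from above). Note that for this KernelExplainer the SHAP values live directly in probability space, because f returns probabilities:

# base value + sum(SHAP values) ≈ f(x) for a single sample
i = 0
print(f(X_train.values[i:i+1])[0], explainer.expected_value + shap_values[i].sum())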
shap.summary_plot(shap_values, X_train,plot_type="bar")

[Figure: summary bar plot of mean |SHAP| per feature (KernelExplainer)]

shap.dependence_plot('Age', shap_values, X_train,interaction_index=None)

[Figure: dependence plot of Age vs. its SHAP value (KernelExplainer)]

shap.dependence_plot('Age', shap_values, X_train,interaction_index='Pclass')

[Figure: dependence plot of Age vs. its SHAP value, colored by Pclass (KernelExplainer)]
