当前位置:   article > 正文

python可视化——热力图heatmap seaborn库_python sns.heatmap

python sns.heatmap

一、基本介绍

Seaborn是基于matplotlib的Python可视化库。它提供了一个高级界面来绘制有吸引力的统计图形。Seaborn其实是在matplotlib的基础上进行了更高级的API封装,从而使得作图更加容易,不需要经过大量的调整就能使你的图变得精致。热力图在实际中常用于展示一组变量的相关系数矩阵,在展示列联表的数据分布上也有较大的用途,通过热力图我们可以非常直观地感受到数值大小的差异状况,方便我们知道模型是如何看待图片的,也方便我们检测出模型的偏向(bias)。heatmap的API如下所示:

seaborn.heatmap(data, vmin=None, vmax=None,cmap=None, center=None, robust=False, annot=None, fmt=’.2g’, annot_kws=None,linewidths=0, linecolor=’white’, cbar=True, cbar_kws=None, cbar_ax=None,square=False, xticklabels=’auto’, yticklabels=’auto’, mask=None, ax=None,**kwargs)

二、参数介绍 

2.1热力图输入数据参数:data

data:矩阵数据集,可以是numpy的数组(array),也可以是pandas的DataFrame。如果是DataFrame,则df的index/column信息会分别对应到heatmap的columns和rows,即df.index是热力图的行标,df.columns是热力图的列标

2.2热力图矩阵块颜色参数:vmin,vmax,cmap,center,robust

vmax,vmin:分别是热力图的颜色取值最大和最小范围,默认是根据data数据表里的取值确定
cmap:从数字到色彩空间的映射,取值是matplotlib包里的colormap名称或颜色对象,或者表示颜色的列表;改参数默认值:根据center参数设定
center:数据表取值有差异时,设置热力图的色彩中心对齐值;通过设置center值,可以调整生成的图像颜色的整体深浅;设置center数据时,如果有数据溢出,则手动设置的vmax、vmin会自动改变
robust:默认取值False;如果是False,且没设定vmin和vmax的值,热力图的颜色映射范围根据具有鲁棒性的分位数设定,而不是用极值设定

2.3热力图矩阵块注释参数:annot,fmt,annot_kws

annot(annotate的缩写):默认取值False;如果是True,在热力图每个方格写入数据;如果是矩阵,在热力图每个方格写入该矩阵对应位置数据
fmt:字符串格式代码,矩阵上标识数字的数据格式,比如保留小数点后几位数字
annot_kws:默认取值False;如果是True,设置热力图矩阵上数字的大小颜色字体,matplotlib包text类下的字体设置;官方文档:

2.4热力图矩阵块之间间隔及间隔线参数:linewidths,linecolor

linewidths:定义热力图里“表示两两特征关系的矩阵小块”之间的间隔大小
linecolor:切分热力图上每个矩阵小块的线的颜色,默认值是’white’

2.5热力图颜色刻度条参数:cbar,cbar_kws,cbar_ax

cbar:是否在热力图侧边绘制颜色刻度条,默认值是True
cbar_kws:热力图侧边绘制颜色刻度条时,相关字体设置,默认值是None
cbar_ax:热力图侧边绘制颜色刻度条时,刻度条位置设置,默认值是None

2.6设置热力图矩阵小块形状:square

square:默认值是False

2.7控制标签名的输出:xticklabels,yticklabels

xticklabels, yticklabels:xticklabels控制每列标签名的输出;yticklabels控制每行标签名的输出。默认值是auto。如果是True,则以DataFrame的列名作为标签名。如果是False,则不添加行标签名。如果是列表,则标签名改为列表中给的内容。如果是整数K,则在图上每隔K个标签进行一次标注。 如果是auto,则自动选择标签的标注间距,将标签名不重叠的部分(或全部)输出

2.8其他参数mask,ax,**kwargs 

mask:控制某个矩阵块是否显示出来。默认值是None。如果是布尔型的DataFrame,则将DataFrame里True的位置用白色覆盖掉
ax:设置作图的坐标轴,一般画多个子图时需要修改不同的子图的该值
**kwargs:All other keyword arguments are passed to ax.pcolormesh

三、具体示例 

3.1简单例子

1.为防止乱码,添加utf-8和unicode_escape,并在中文路径处指定gbk编码格式

# coding : utf-8

# coding : unicode-escape

2.导入各种库

  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. import seaborn as sns

 完整代码

  1. # coding : utf-8
  2. # coding : unicode-escape
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib.pyplot as plt
  6. import seaborn as sns
  7. # Plot a heatmap for a numpy array:
  8. uniform_data = np.random.rand(12,10)
  9. ax = sns.heatmap(uniform_data)
  10. plt.show()

运行结果:(默认参数)

 

添加修改各种参数:

1.指定颜色的数值范围:vmax,vmin 

  1. uniform_data = np.random.rand(12,10)
  2. ax = sns.heatmap(uniform_data, vmin=0, vmax=1)
  3. plt.show()

输出结果: 

 

2.确定颜色映射的中心:center

  1. # Plot a heatmap for data centered on 0 with a diverging colormap:
  2. normal_data = np.random.randn(10, 12)
  3. ax = sns.heatmap(normal_data, center=0)
  4. plt.show()

输出结果:

  

3.2实际例子

3.2.1直接读取csv,画热力图

  1. # coding : utf-8
  2. # coding : unicode-escape
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib.pyplot as plt
  6. import seaborn as sns
  7. df1 = pd.read_csv('D:/myCode/spark/spark_ML/df1.csv')
  8. pd.set_option('max_columns', 100) #显示最多列数,超出该数以省略号表示
  9. # 打印数据信息前5行的具体内容
  10. print(df1.head())
  11. # 画热力图
  12. plt.figure(figsize=(16, 9),dpi=300)
  13. corr = df1.corr()
  14. sns.heatmap(corr,cmap='Reds')
  15. plt.show()

 输出和打印结果:

 3.2.2在矩阵色块中添加注释:annot

  1. # 画热力图
  2. plt.figure(figsize=(16, 9),dpi=300)
  3. corr = df1.corr()
  4. sns.heatmap(corr,cmap='Reds',annot=True)
  5. plt.show()

 3.2.3在色块间添加分割线和指定分割线宽度和颜色:linewidths,linecolor

  1. # 画热力图
  2. plt.figure(figsize=(16, 9),dpi=300)
  3. corr = df1.corr()
  4. sns.heatmap(corr,cmap='Reds',annot=True,linewidths=0.05,linecolor="red")
  5. plt.show()

 3.2.4改变色带:cmap

在这之前,先补充下,如前面几个图,横纵坐标的中文都没有显示,只有方框,所以只需要在代码开头引入这三行代码即可

  1. from pylab import *
  2. mpl.rcParams['font.sans-serif'] = ['SimHei']
  3. mpl.rcParams['axes.unicode_minus'] = False
  1. # 画热力图
  2. plt.figure(figsize=(16, 9),dpi=300)
  3. corr = df1.corr()
  4. # 之前的上几个图的cmap='Reds'
  5. sns.heatmap(corr,cmap='YlGnBu',annot=True,linewidths=0.05,linecolor="red")
  6. plt.show()

 3.2.5设置坐标字体、方向

  1. # 画热力图
  2. plt.figure(figsize=(16, 9),dpi=300)
  3. corr = df1.corr()
  4. ax = sns.heatmap(corr,cmap='Reds',annot=True,linewidths=0.05,linecolor="red")
  5. # 设置坐标字体
  6. ax.set_yticklabels(ax.get_yticklabels(), rotation=45)
  7. ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
  8. plt.show()


 2021年4月3日更新

发现有时候热力图会显示一部分:绘制热力图最上面一行和最下面一行都只显示一半。

解决方法:

原因:据说是所用版本matplotlib的bug
解决方法:
1、更改matplotlib版本

   安装其他版本,我回退到了3.0.3的版本

   conda install matplotlib=3.0.3

2、加入以下代码即可:

ax = sns.heatmap(...);
bottom, top = ax.get_ylim()
ax.set_ylim(bottom + 0.5, top - 0.5)

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/153550
推荐阅读
相关标签
  

闽ICP备14008679号