当前位置:   article > 正文

seaborn绘制热力图_seaborn 热力图

seaborn 热力图

目录

1、普通绘制热力图

2、坐标轴标签太多,自定义标签显示

3、不显示热图的网格

4、自定义颜色条的距离、标签

5、显示对角线的值

6、一些参数


1、普通绘制热力图

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # 创建数据
  6. data = np.random.random((7,12))
  7. # 计算相关性
  8. corr = np.corrcoef(data)
  9. # 设置需要显示的标签
  10. bands_wavelength = ["400","500","600","700","800","900","1000"]
  11. mask = np.zeros_like(corr,dtype=np.bool_)
  12. mask[np.tril_indices_from(mask)] = True
  13. cmap = sns.diverging_palette(220,10,as_cmap=True)
  14. corr = np.flip(corr, axis=0)
  15. mask = np.flip(mask, axis=0)
  16. ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.5,
  17. vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
  18. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
  19. # plt.title("平方数", fontsize=24) # 设置标题
  20. # plt.xlabel("值", fontsize=14) # 设置x标题
  21. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  22. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  23. plt.show()

2、坐标轴标签太多,自定义标签显示

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # 创建数据
  6. data = np.random.random((100,12))
  7. # 计算相关性
  8. corr = np.corrcoef(data)
  9. # 设置需要显示的标签
  10. bands_wavelength = ["400","500","600","700","800","900","1000"]
  11. mask = np.zeros_like(corr,dtype=np.bool_)
  12. mask[np.tril_indices_from(mask)] = True
  13. cmap = sns.diverging_palette(220,10,as_cmap=True)
  14. corr = np.flip(corr, axis=0)
  15. mask = np.flip(mask, axis=0)
  16. ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.5,
  17. vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
  18. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
  19. dx = data.shape[0]/(len(bands_wavelength)-1)
  20. ax.set_xticks([dx*i for i in range(len(bands_wavelength))])
  21. ax.set_yticks([dx*i for i in range(len(bands_wavelength))])
  22. ax.set_xticklabels(bands_wavelength)
  23. ax.set_yticklabels(bands_wavelength[::-1])
  24. # plt.title("平方数", fontsize=24) # 设置标题
  25. # plt.xlabel("值", fontsize=14) # 设置x标题
  26. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  27. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  28. plt.show()

3、不显示热图的网格

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # 创建数据
  6. data = np.random.random((100,12))
  7. # 计算相关性
  8. corr = np.corrcoef(data)
  9. # 设置需要显示的标签
  10. bands_wavelength = ["400","500","600","700","800","900","1000"]
  11. mask = np.zeros_like(corr,dtype=np.bool_)
  12. mask[np.tril_indices_from(mask)] = True
  13. cmap = sns.diverging_palette(220,10,as_cmap=True)
  14. corr = np.flip(corr, axis=0)
  15. mask = np.flip(mask, axis=0)
  16. # 修改linewidths为0即可
  17. ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.,
  18. vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
  19. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
  20. dx = data.shape[0]/(len(bands_wavelength)-1)
  21. ax.set_xticks([dx*i for i in range(len(bands_wavelength))])
  22. ax.set_yticks([dx*i for i in range(len(bands_wavelength))])
  23. ax.set_xticklabels(bands_wavelength)
  24. ax.set_yticklabels(bands_wavelength[::-1])
  25. # plt.title("平方数", fontsize=24) # 设置标题
  26. # plt.xlabel("值", fontsize=14) # 设置x标题
  27. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  28. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  29. plt.show()

4、自定义颜色条的距离、标签

方法1,推荐:

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. import matplotlib.ticker as tkr
  6. # 创建数据
  7. data = np.random.random((100, 12))
  8. # 计算相关性
  9. corr = np.corrcoef(data)
  10. # 设置需要显示的标签
  11. bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
  12. mask = np.zeros_like(corr, dtype=np.bool_)
  13. mask[np.tril_indices_from(mask)] = True
  14. cmap = sns.diverging_palette(220, 10, as_cmap=True)
  15. corr = np.flip(corr, axis=0)
  16. mask = np.flip(mask, axis=0)
  17. # ---------设置颜色条----------
  18. cbar_ticks = [-1, -0.5, 0, 0.5, 1]
  19. formatter = tkr.ScalarFormatter(useMathText=True)
  20. formatter.set_scientific(False)
  21. # ----------------------------
  22. # cbar_kws的参数说明
  23. # 'label': #color bar的名称
  24. # 'ticks':#color bar中刻度值范围和间隔
  25. # 'format':'%.0f',#格式化输出color bar中刻度值
  26. # 'pad':0.15,#color bar与热图之间距离,距离变大热图会被压缩
  27. ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.,
  28. vmin=np.min(corr), vmax=np.max(corr), cbar=True, # vmin和vmax是自定义显示颜色的范围
  29. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1]
  30. ,cbar_kws={"ticks": cbar_ticks, "format": formatter,'label':'A','pad':0.01})
  31. ax.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
  32. # 设置colorbar的刻度字体大小
  33. cax = plt.gcf().axes[-1]
  34. cax.tick_params(labelsize=11)
  35. # 设置colorbar的label文本和字体大小
  36. font1 = {'family':'Times New Roman','size':11, 'color':'#000000'} #热图以及colorbar的字体
  37. cbar = ax.collections[0].colorbar
  38. cbar.set_label('A',fontdict=font1)
  39. # ---------设置自定义刻度标签----------
  40. dx = data.shape[0] / (len(bands_wavelength) - 1)
  41. ax.set_xticks([dx * i for i in range(len(bands_wavelength))])
  42. ax.set_yticks([dx * i for i in range(len(bands_wavelength))])
  43. ax.set_xticklabels(bands_wavelength)
  44. ax.set_yticklabels(bands_wavelength[::-1])
  45. # ----------------------------
  46. # plt.title("平方数", fontsize=24) # 设置标题
  47. # plt.xlabel("值", fontsize=14) # 设置x标题
  48. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  49. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  50. plt.show()

方法2: 

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. from mpl_toolkits.axes_grid1 import make_axes_locatable
  6. # 创建数据
  7. data = np.random.random((100, 12))
  8. # 计算相关性
  9. corr = np.corrcoef(data)
  10. # 设置需要显示的标签
  11. bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
  12. mask = np.zeros_like(corr, dtype=np.bool_)
  13. mask[np.tril_indices_from(mask)] = True
  14. cmap = sns.diverging_palette(220, 10, as_cmap=True)
  15. corr = np.flip(corr, axis=0)
  16. mask = np.flip(mask, axis=0)
  17. ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.,
  18. vmin=np.min(corr), vmax=np.max(corr), cbar=False, # vmin和vmax是自定义显示颜色的范围
  19. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
  20. # ---------设置颜色条----------
  21. divider = make_axes_locatable(ax)
  22. cax = divider.append_axes("right", size="5%", pad=0.1)
  23. cbar = plt.colorbar(ax.collections[0], cax=cax)
  24. cbar.set_ticks([-0.667, 0, 0.667])
  25. cbar.ax.set_yticklabels(['-1', '0', '1'], size=20)
  26. cbar.ax.tick_params(axis='y', which='major', length=0, pad=15)
  27. cbar.outline.set_edgecolor('black')
  28. cbar.outline.set_linewidth(0.01)
  29. # ----------------------------
  30. # ---------设置自定义刻度标签----------
  31. dx = data.shape[0] / (len(bands_wavelength) - 1)
  32. ax.set_xticks([dx * i for i in range(len(bands_wavelength))])
  33. ax.set_yticks([dx * i for i in range(len(bands_wavelength))])
  34. ax.set_xticklabels(bands_wavelength)
  35. ax.set_yticklabels(bands_wavelength[::-1])
  36. # ----------------------------
  37. # plt.title("平方数", fontsize=24) # 设置标题
  38. # plt.xlabel("值", fontsize=14) # 设置x标题
  39. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  40. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  41. plt.show()

5、显示对角线的值

  1. # -*- coding:utf-8 _*-
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. # 创建数据
  6. data = np.random.random((7,12))
  7. # 计算相关性
  8. corr = np.corrcoef(data)
  9. # 设置需要显示的标签
  10. bands_wavelength = ["400","500","600","700","800","900","1000"]
  11. mask = np.zeros_like(corr,dtype=np.bool_)
  12. mask[np.tril_indices_from(mask)] = True
  13. cmap = sns.diverging_palette(220,10,as_cmap=True)
  14. corr = np.flip(corr, axis=0)
  15. mask = np.flip(mask, axis=1)
  16. ax = sns.heatmap(corr,mask=~mask.T,cmap=cmap,square=True,linewidths=0.5,
  17. vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
  18. xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
  19. # plt.title("平方数", fontsize=24) # 设置标题
  20. # plt.xlabel("值", fontsize=14) # 设置x标题
  21. # plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
  22. # plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
  23. plt.show()

6、一些参数

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annot_kws=None, linewidths=0, linecolor='white', cbar=True, cbar_kws=None, cbar_ax=None, square=False, xticklabels='auto', yticklabels='auto', mask=None, ax=None, **kwargs)
  • annot

annotate的缩写,annot默认为False,当annot为True时,在heatmap中每个方格写入数据。

annot_kws,当annot为True时,可设置各个参数,包括大小,颜色,加粗,斜体字等。

  1. import numpy as np
  2. import seaborn as sns
  3. from matplotlib.colors import LogNorm
  4. import matplotlib.ticker as tkr
  5. from matplotlib import pyplot as plt
  6. matrix = np.random.rand(10, 10) / 0.4
  7. vmax = 2
  8. vmin = 0.5
  9. cbar_ticks = [0.5, 0.75, 1, 1.33, 2]
  10. formatter = tkr.ScalarFormatter(useMathText=True)
  11. formatter.set_scientific(False)
  12. log_norm = LogNorm(vmin=vmin, vmax=vmax)
  13. ax_ = sns.heatmap(matrix,square=True, vmax=vmax, vmin=vmin, norm=log_norm,
  14. cbar_kws={"ticks": cbar_ticks, "format": formatter},cmap="jet",
  15. annot=True,annot_kws={'size':6,'weight':'bold', 'color':'w'})
  16. ax_.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
  17. plt.show()

  • fmt

字符串,可选参数。添加注释时要使用的字符串格式代码。

  1. import numpy as np
  2. import seaborn as sns
  3. from matplotlib.colors import LogNorm
  4. import matplotlib.ticker as tkr
  5. from matplotlib import pyplot as plt
  6. matrix = np.random.rand(10, 10) / 0.4
  7. vmax = 2
  8. vmin = 0.5
  9. cbar_ticks = [0.5, 0.75, 1, 1.33, 2]
  10. formatter = tkr.ScalarFormatter(useMathText=True)
  11. formatter.set_scientific(False)
  12. log_norm = LogNorm(vmin=vmin, vmax=vmax)
  13. ax_ = sns.heatmap(matrix,square=True, vmax=vmax, vmin=vmin, norm=log_norm,
  14. cbar_kws={"ticks": cbar_ticks, "format": formatter},cmap="jet",
  15. annot=True,annot_kws={'size':6,'weight':'bold', 'color':'w'},
  16. fmt='.2f')
  17. ax_.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
  18. plt.show()

  • ax

matplotlib Axes,可选参数。多子图时使用,用于指定绘制图的坐标轴(即哪个子图),否则使用当前活动的坐标轴。

  1. import numpy as np
  2. import seaborn as sns
  3. from matplotlib.colors import LogNorm
  4. import matplotlib.ticker as tkr
  5. from matplotlib import pyplot as plt
  6. matrix = np.random.rand(10, 10) / 0.4
  7. vmax = 2
  8. vmin = 0.5
  9. cbar_ticks = [0.5, 0.75, 1, 1.33, 2]
  10. formatter = tkr.ScalarFormatter(useMathText=True)
  11. formatter.set_scientific(False)
  12. log_norm = LogNorm(vmin=vmin, vmax=vmax)
  13. fig, ax = plt.subplots(figsize=(8,8),nrows=2)
  14. ax_ = sns.heatmap(matrix, ax=ax[0],square=True, vmax=vmax, vmin=vmin, norm=log_norm,
  15. cbar_kws={"ticks": cbar_ticks, "format": formatter},cmap="jet",
  16. annot=True,annot_kws={'size':6,'weight':'bold', 'color':'w'})
  17. ax_.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
  18. ax__ = sns.heatmap(matrix, ax=ax[1],square=True, vmax=vmax, vmin=vmin, norm=log_norm,
  19. cbar_kws={"ticks": cbar_ticks, "format": formatter},cmap="jet",
  20. annot=True,annot_kws={'size':6,'weight':'bold', 'color':'w'})
  21. ax__.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
  22. plt.show()

  • vmax,vmin, 用于锚定色彩映射的值,否则它们是从数据和其他关键字参数推断出来的。也是图例中最大值和最小值的显示值,没有该参数时默认不显示,同时也是显示的颜色映射。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/450216
推荐阅读
相关标签
  

闽ICP备14008679号