当前位置:   article > 正文

Transformer得到的图像attention map可视化,同时叠加到原图

attention map可视化

transformer输出的attention map形状为(bs,q,k)其中bs为batch size,q为query的序列长(这里设为16),k为key的序列长(这里表示图像feature的patch数192=24*8)

已知attention map=bs,16,192

import matplotlib.pyplot as plt
for i in range(bs):  #遍历每一个bs进行可视化
	att1 = (att[i, 0, :] ).view(24, 8).unsqueeze(0).unsqueeze(0) #这里取query的第一行与所有图像patch的attention,也可以可视化query的其它行,然后将192个patch复原为24*8的形状,同时用unsqueeze扩出两维得到1,1,24,8
    att1 = F.interpolate(att1,size=[384,128],mode='bilinear') #将attention插值放大到与原图一样的大小得到1,1,384,128
    im = Image.open(img_path[i]) #读取图像路径,打开图像
    im = im.resize((128, 384)) #图像resize到384长,128宽
    plt.imshow(im)#设置plt可视化图层为原图
    plt.imshow(att1.squeeze().cpu().numpy(),alpha=0.4,cmap='rainbow')#这行将attention图叠加显示,透明度0.4
    plt.axis('off')#关闭坐标轴
                   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

得到的原图叠加attetnion效果图如下
在这里插入图片描述

2、语言attention可视化
例tokens=[‘the’,‘men’,‘in’,‘red’]

import seaborn as sns
heatmap = sns.heatmap((att.unsqueeze(0)).cpu().numpy(),cmap='Blues') 
plt.xticks(range(len(tokens)),tokens[:-1], rotation=45) #设置坐标轴标签
plt.show()
  • 1
  • 2
  • 3
  • 4
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/酷酷是懒虫/article/detail/764373
推荐阅读
相关标签
  

闽ICP备14008679号