赞
踩
transformer输出的attention map形状为(bs,q,k)其中bs为batch size,q为query的序列长(这里设为16),k为key的序列长(这里表示图像feature的patch数192=24*8)
已知attention map=bs,16,192
import matplotlib.pyplot as plt
for i in range(bs): #遍历每一个bs进行可视化
att1 = (att[i, 0, :] ).view(24, 8).unsqueeze(0).unsqueeze(0) #这里取query的第一行与所有图像patch的attention,也可以可视化query的其它行,然后将192个patch复原为24*8的形状,同时用unsqueeze扩出两维得到1,1,24,8
att1 = F.interpolate(att1,size=[384,128],mode='bilinear') #将attention插值放大到与原图一样的大小得到1,1,384,128
im = Image.open(img_path[i]) #读取图像路径,打开图像
im = im.resize((128, 384)) #图像resize到384长,128宽
plt.imshow(im)#设置plt可视化图层为原图
plt.imshow(att1.squeeze().cpu().numpy(),alpha=0.4,cmap='rainbow')#这行将attention图叠加显示,透明度0.4
plt.axis('off')#关闭坐标轴
得到的原图叠加attetnion效果图如下
2、语言attention可视化
例tokens=[‘the’,‘men’,‘in’,‘red’]
import seaborn as sns
heatmap = sns.heatmap((att.unsqueeze(0)).cpu().numpy(),cmap='Blues')
plt.xticks(range(len(tokens)),tokens[:-1], rotation=45) #设置坐标轴标签
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。