
TensorFlow implementation of the Criss Cross Attention module from the semantic-segmentation paper CCNet: Criss-Cross Network


TensorFlow implementation of the Criss Cross Attention module

I wrote this code while still learning the topic, so if you spot any problems or have corrections, please get in touch!

Module structure

(figure: structure of the Criss Cross Attention module)

Affinity operation

import tensorflow as tf
from tensorflow.keras import layers


class criss_cross_attention_Affinity(tf.keras.layers.Layer):
    """For every pixel, computes the affinity energies between its feature
    vector and the H + W - 1 pixels on its criss-cross path (the rest of
    its row plus its full column)."""

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Affinity, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        batch_size, H, W, Channel = x.shape
        outputs = []
        for i in range(H):
            for j in range(W):
                ver = x[:, i, j, :]  # feature vector of pixel (i, j): [B, C]
                # criss-cross set: row i without (i, j), plus the whole column j
                temp_x = tf.concat([x[:, i, 0:j, :], x[:, i, j + 1:W, :], x[:, :, j, :]], axis=1)  # [B, H+W-1, C]
                trans_temp = tf.matmul(temp_x, tf.expand_dims(ver, -1))  # [B, H+W-1, 1]
                trans_temp = tf.squeeze(trans_temp, -1)                  # [B, H+W-1]
                trans_temp = tf.expand_dims(trans_temp, axis=1)          # [B, 1, H+W-1]
                outputs.append(trans_temp)
        outputs = tf.concat(outputs, axis=self.axis)  # [B, H*W, H+W-1]
        C = outputs.shape[2]
        outputs = tf.reshape(outputs, [-1, H, W, C])  # [B, H, W, H+W-1]
        return outputs

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Affinity, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
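To make the shapes concrete, here is a hypothetical standalone NumPy sketch (not the Keras layer itself; the `criss_cross_set` helper is my own name) of what the Affinity step computes for a single pixel on a single feature map: dot products between its feature vector and the H + W - 1 features on its criss-cross path.

```python
import numpy as np

H, W, C = 4, 5, 3
rng = np.random.default_rng(0)
x = rng.random((H, W, C))  # one feature map, no batch axis

def criss_cross_set(x, i, j):
    # rest of row i (pixel (i, j) excluded), then the whole column j
    row = np.concatenate([x[i, :j], x[i, j + 1:]], axis=0)  # (W-1, C)
    col = x[:, j]                                           # (H, C)
    return np.concatenate([row, col], axis=0)               # (H+W-1, C)

neighbors = criss_cross_set(x, 2, 3)
energies = neighbors @ x[2, 3]  # (H+W-1,) affinity energies for pixel (2, 3)
```

The Keras layer above does exactly this per pixel, with an extra batch dimension, and stacks the per-pixel energy vectors back into a [B, H, W, H+W-1] map.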

Aggregation operation

class criss_cross_attention_Aggregation(tf.keras.layers.Layer):
    """Softmax-normalizes the affinity energies and uses them to take a
    weighted sum over each pixel's criss-cross set of feature vectors."""

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Aggregation, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x, Affinity):
        batch_size, H, W, Channel = x.shape
        # softmax over the last axis, i.e. over each pixel's H+W-1 energies
        Affinity = tf.nn.softmax(Affinity, axis=-1)
        outputs = []
        for i in range(H):
            for j in range(W):
                ver = Affinity[:, i, j, :]  # attention weights of pixel (i, j): [B, H+W-1]
                # same criss-cross set as in the Affinity layer
                temp_x = tf.concat([x[:, i, 0:j, :], x[:, i, j + 1:W, :], x[:, :, j, :]], axis=1)  # [B, H+W-1, C]
                trans_temp = tf.matmul(tf.transpose(tf.expand_dims(ver, -1), [0, 2, 1]), temp_x)  # [B, 1, C]
                outputs.append(trans_temp)
        outputs = tf.concat(outputs, axis=self.axis)  # [B, H*W, C]
        C = outputs.shape[2]
        outputs = tf.reshape(outputs, [-1, H, W, C])
        return outputs

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Aggregation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
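As a hypothetical standalone NumPy sketch (independent of the Keras layers; the `aggregate_pixel` helper is my own), per-pixel aggregation is just a softmax over the criss-cross energies followed by a weighted sum of the criss-cross features:

```python
import numpy as np

H, W, C = 4, 5, 3
rng = np.random.default_rng(1)
x = rng.random((H, W, C))  # one feature map, no batch axis

def aggregate_pixel(x, i, j):
    # criss-cross set: rest of row i plus the whole column j -> (H+W-1, C)
    nb = np.concatenate([x[i, :j], x[i, j + 1:], x[:, j]], axis=0)
    e = nb @ x[i, j]          # affinity energies (H+W-1,)
    w = np.exp(e - e.max())
    w /= w.sum()              # softmax weights, sum to 1
    return w @ nb             # aggregated feature vector (C,)

out = aggregate_pixel(x, 2, 3)
```

Because the weights form a convex combination, each output channel stays within the range of the input features.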

Combining the two operations

def criss_cross_attention(x):
    # downsample first so the per-pixel loops stay affordable
    x = layers.Conv2D(filters=64, kernel_size=3, padding='same', strides=2)(x)
    x_origin = x
    affinity = criss_cross_attention_Affinity(1)(x)
    out = criss_cross_attention_Aggregation(1)(x, affinity)
    out = layers.Add()([out, x_origin])  # residual connection back to the input
    out = layers.UpSampling2D(size=2, interpolation='bilinear')(out)
    return out

Printed model summary

(figure: printed model summary)

Known issue

Because this module computes, for each pixel, attention over the criss-cross set of row and column pixels centered on that pixel in the feature map, the code loops over every pixel. This makes the computation expensive, and I have not yet found a way around it.
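One possible way to avoid the per-pixel loops, sketched here in NumPy under my own assumptions (single feature map, no batch axis; the function names and einsum formulation are mine, not from the CCNet paper's released code), is to compute all row energies and all column energies as two batched matrix products, masking the duplicated center pixel in the row part so it is only counted once via its column:

```python
import numpy as np

def cc_affinity_vectorized(x):
    """All criss-cross affinity energies at once for one (H, W, C) map.

    Returns (H, W, W + H): per pixel, W row energies (its own row entry
    masked to -inf so the center is only counted via the column part)
    followed by H column energies.
    """
    H, W, C = x.shape
    e_row = np.einsum('iwc,ivc->iwv', x, x)  # (H, W, W) row energies
    idx = np.arange(W)
    e_row[:, idx, idx] = -np.inf             # mask the duplicate center pixel
    e_col = np.einsum('ijc,kjc->ijk', x, x)  # (H, W, H) column energies
    return np.concatenate([e_row, e_col], axis=-1)

def cc_aggregate_vectorized(x, energies):
    """Softmax over each pixel's W + H energies, then the weighted sum of
    its row and column features. Returns (H, W, C)."""
    H, W, C = x.shape
    e = energies - energies.max(axis=-1, keepdims=True)
    a = np.exp(e)                             # exp(-inf) -> 0: center row
    a /= a.sum(axis=-1, keepdims=True)        # entry gets zero weight
    a_row, a_col = a[..., :W], a[..., W:]
    return (np.einsum('iwv,ivc->iwc', a_row, x)
            + np.einsum('ijk,kjc->ijc', a_col, x))
```

Since the -inf row entry gets zero softmax weight, each pixel effectively attends to the same H + W - 1 positions as the loop version, so the outputs should agree; the same idea should carry over to batched `tf.matmul`/`tf.einsum` calls.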

This article was contributed by a community member; when reposting, please credit the source: 【wpsshop博客】.