MindSpore Check-in: DDPM UNet Network Analysis
A:
Why do the channels in the downsampling part of the DDPM UNet rise as 20 → 32 → 64 → 128? Looking at the U shape, shouldn't they be going down?
{Block1 --> Block2 --> Res(attn) --> Downsample} × 3
B:
It is the width and height that go down; the channel count goes up.
The upsampling part reverses this: width and height grow back, and the channel count eventually returns to 3.
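A quick plain-Python sketch (hypothetical numbers, using the dims list [20, 32, 64, 128] derived below) makes this concrete: each down stage raises the channel count, while the Downsample at its end halves the height and width, except for the last stage, which keeps the spatial size:
```python
# Shape trace of the down path for a 32x32 input; dims matches the Unet below.
dims = [20, 32, 64, 128]
h = w = 32
stages = list(zip(dims[:-1], dims[1:]))
for ind, (c_in, c_out) in enumerate(stages):
    is_last = ind == len(stages) - 1
    if not is_last:
        h, w = h // 2, w // 2  # Downsample halves height and width
    print(f"stage {ind}: channels {c_in} -> {c_out}, feature map {h}x{w}")
```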
### Conditional U-Net

We have now defined all the building blocks (position embeddings, ResNet/ConvNeXT blocks, attention, and group normalization), so it is time to define the full neural network. Recall that the job of the network $\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)$ is to take a batch of noisy images and their noise levels, and to output the noise that was added to the input.

More concretely:

The network takes a batch of noisy images of shape `(batch_size, num_channels, height, width)` and a batch of noise levels of shape `(batch_size, 1)` as input, and returns a tensor of shape `(batch_size, num_channels, height, width)`.

The network is built up as follows:

- First, a convolutional layer is applied to the batch of noisy images, and position embeddings are computed for the noise levels
- Next, a sequence of downsampling stages is applied. Each downsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + a downsample operation
- At the middle of the network, ResNet or ConvNeXT blocks are applied again, interleaved with attention
- Next, a sequence of upsampling stages is applied. Each upsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + an upsample operation
- Finally, a ResNet/ConvNeXT block is applied, followed by a convolutional layer

In the end, the neural network stacks up layers as if they were Lego blocks (but it is important to [understand how they work](http://karpathy.github.io/2019/04/25/recipe/)).
```python
class Unet(nn.Cell):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
convnext_mult=2,
):
super().__init__()
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ConvNextBlock, mult=convnext_mult)
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.SequentialCell(
SinusoidalPositionEmbeddings(dim),
nn.Dense(dim, time_dim),
nn.GELU(),
nn.Dense(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
self.downs = nn.CellList([])
self.ups = nn.CellList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.CellList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.CellList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.SequentialCell(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def construct(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
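            # stash the attended feature; it becomes the skip connection for the matching up stage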
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
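        # walk the saved skip features from deepest to shallowest (a manual index instead of h.pop())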
len_h = len(h) - 1
for block1, block2, attn, upsample in self.ups:
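            # concatenate the skip feature along the channel axis (this is why the first up block takes dim_out * 2 channels)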
x = ops.concat((x, h[len_h]), 1)
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
```
```python
import mindspore as ms
from mindspore.common.initializer import Normal
# Parameter definitions
image_side_length = 32  # width and height of the image, in pixels
channels = 3  # number of image channels; we assume RGB images here
batch_size = 2  # batch size
# Define the Unet model
# Note: dim should be chosen according to the model design; here we keep it as given
unet_model = Unet(dim=image_side_length, channels=channels, dim_mults=(1, 2, 4,))
# Build the input data
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
print(x)  # print the data (shows the randomly initialized values)
```
[[[[ 1.22990236e-02 9.65940859e-03 -5.95777121e-04 ... -1.09354462e-02
2.30002552e-02 -5.25823655e-03]
[ 1.35805225e-02 1.16471937e-02 -1.20973922e-02 ... -1.13204606e-02
-1.91520341e-02 -1.09745166e-03]
[-4.65569133e-03 1.33861918e-02 -1.60518996e-02 ... 4.18792450e-04
9.22567211e-03 4.44417645e-04]
...
[ 3.40697076e-03 4.53335233e-03 5.73999388e-03 ... 4.67619160e-03
-8.16432573e-03 -1.39179081e-02]
[-9.07978602e-03 -6.43689744e-03 1.32928183e-02 ... 4.21820907e-03
-1.05559649e-02 8.33686162e-03]
[ 2.96656298e-03 -7.44550209e-03 5.52403228e-03 ... -2.09826510e-03
2.17068940e-02 2.28530783e-02]]
[[-2.34551495e-03 7.68061494e-03 8.63175746e-03 ... -5.62175177e-03
-9.85390134e-03 -4.08322597e-03]
[ 1.30044697e-02 -9.87336412e-03 2.55680992e-03 ... 1.21581517e-02
1.10829184e-02 -1.09381862e-02]
[-1.09032113e-02 1.25320591e-02 -9.15124733e-03 ... -8.42134352e-04
-3.48115107e-03 -8.12307373e-03]
...
[-1.22983279e-02 2.11556954e-03 -1.63072231e-03 ... -8.83890502e-03
2.00234205e-02 -2.91514886e-03]
[-4.95374482e-03 -1.51413877e-03 6.57585217e-03 ... 1.93616766e-02
-3.65696964e-03 -1.76955778e-02]
[ 8.47856048e-03 9.17020999e-03 -5.66793000e-03 ... -2.92802905e-03
-5.98460436e-03 8.32138583e-03]]
[[ 1.00378189e-02 -2.43024575e-03 2.11097375e-02 ... -6.47504721e-03
-1.47426147e-02 7.38033140e-03]
[-3.09416349e-03 -3.46184568e-03 -7.74018466e-03 ... 1.19950040e-03
3.14799254e-04 -7.95779750e-03]
[ 3.98837449e-03 2.33123749e-02 1.63442008e-02 ... 1.05365906e-02
-1.44729228e-03 1.90633966e-03]
...
[-1.76522471e-02 9.42215510e-03 -9.92319733e-03 ... -8.83952528e-03
-1.18930812e-03 -8.53374321e-03]
[ 2.51283534e-02 -1.38457380e-02 1.32035371e-02 ... 1.66724548e-02
-9.26751085e-03 1.42328264e-02]
[-3.69384699e-03 6.09130273e-03 -2.94976344e-04 ... 7.72336172e-03
-3.75742209e-03 -3.17590404e-03]]]
[[[-2.92081665e-03 -1.39991604e-02 -8.93703103e-03 ... 1.51352473e-02
3.90937366e-03 2.66693830e-02]
[-2.27847677e-02 3.63694108e-03 2.70780316e-03 ... -8.13330431e-03
-4.17956570e-03 1.22072157e-02]
[-1.24624427e-02 4.75015305e-03 2.68556597e-03 ... 6.48784591e-03
-6.09957753e-03 4.85362951e-03]
...
[-3.67846363e-03 -9.81856976e-03 -7.40657933e-03 ... 1.95454084e-03
1.80558003e-02 4.30267537e-03]
[-2.47061905e-02 1.53471017e-03 -2.55961739e-03 ... -6.16029697e-03
-1.19128199e-02 7.23672146e-03]
[-9.77169070e-03 -5.93968621e-03 -1.16010886e-02 ... 1.13449963e-02
7.74116023e-03 -8.25872459e-03]]
[[ 2.42574494e-02 -1.59016773e-02 4.60586976e-03 ... -1.27300173e-02
-2.08083801e-02 1.20891845e-02]
[ 4.98928130e-03 1.58587005e-02 -1.17553072e-02 ... -4.57813032e-03
2.66204093e-04 -1.80527139e-02]
[ 9.97055881e-03 2.07035127e-03 -7.31401029e-04 ... 1.80852767e-02
-2.09929375e-03 4.49541025e-04]
...
[-8.71989876e-04 7.75372284e-03 3.14102072e-05 ... 6.37980178e-04
-1.68553423e-02 -4.13572555e-03]
[ 6.12246012e-03 -1.88669516e-03 1.50548946e-02 ... 9.18534491e-03
1.46157937e-02 5.96544426e-03]
[-5.24167530e-03 2.64895801e-03 7.25612324e-03 ... -5.48065547e-03
-2.98001780e-03 -7.99621455e-03]]
[[ 1.18518099e-02 1.00414380e-02 -3.00463289e-03 ... -3.48429219e-03
1.21912286e-02 -8.21612682e-03]
[ 9.25556850e-03 -1.57560236e-04 7.71128759e-03 ... 3.91136715e-03
1.56383701e-02 8.09505815e-04]
[ 4.79864981e-03 1.88933630e-02 1.73798949e-02 ... 5.97322173e-03
4.30198200e-03 1.52684944e-02]
...
[-9.37487371e-03 5.54391975e-03 4.64118691e-03 ... 6.41342625e-03
1.36971334e-03 -1.25444317e-02]
[-4.26448090e-03 7.79700419e-03 2.39845295e-03 ... -1.18866842e-02
3.74738523e-03 1.07039241e-02]
[-1.02939839e-02 7.36899953e-03 -2.00587343e-02 ... -1.10042403e-02
-1.42604960e-02 -1.37462756e-02]]]]
```python
x.shape  # show the data shape
```
(2, 3, 32, 32)
```python
dim=image_side_length
channels=channels
dim_mults=(1, 2, 4,)
```
```python
init_dim=None
out_dim=None
# dim_mults=(1, 2, 4, 8)
channels=3
with_time_emb=True
convnext_mult=2
```
```python
init_dim = default(init_dim, dim // 3 * 2)
dim,init_dim
```
(32, 20)
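With `dim = 32`, integer division happens first: `32 // 3 * 2 = 10 * 2 = 20`, which is where the initial channel count of 20 comes from.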
```python
channels
```
3
```python
init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
init_conv
```
Conv2d<input_channels=3, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2d30>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefedb2be0>, format=NCHW>
```python
dim, dim_mults,init_dim
```
(32, (1, 2, 4), 20)
```python
(lambda m: dim * m, dim_mults)
```
(<function __main__.<lambda>(m)>, (1, 2, 4))
```python
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
dims
```
[20, 32, 64, 128]
```python
zip(dims[:-1], dims[1:])
```
<zip at 0xfffefc367b80>
```python
dims[:-1], dims[1:]
```
([20, 32, 64], [32, 64, 128])
```python
in_out = list(zip(dims[:-1], dims[1:]))
in_out
```
[(20, 32), (32, 64), (64, 128)]
```python
ConvNextBlock,convnext_mult
```
(__main__.ConvNextBlock, 2)
```python
block_klass = partial(ConvNextBlock, mult=convnext_mult)  # pre-bind mult=convnext_mult as a fixed keyword argument of ConvNextBlock
block_klass
```
functools.partial(<class '__main__.ConvNextBlock'>, mult=2)
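Calling `block_klass(20, 32, time_emb_dim=128)` is therefore equivalent to `ConvNextBlock(20, 32, time_emb_dim=128, mult=2)`: `partial` pre-binds the keyword argument `mult` and forwards everything else.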
```python
with_time_emb
```
True
```python
time_dim = dim * 4
time_dim,dim
```
(128, 32)
```python
time_mlp = nn.SequentialCell(
SinusoidalPositionEmbeddings(dim),
nn.Dense(dim, time_dim),
nn.GELU(),
nn.Dense(time_dim, time_dim),
)
time_mlp
```
SequentialCell<
(0): SinusoidalPositionEmbeddings<>
(1): Dense<input_channels=32, output_channels=128, has_bias=True>
(2): GELU<>
(3): Dense<input_channels=128, output_channels=128, has_bias=True>
>
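`SinusoidalPositionEmbeddings` is defined earlier in the notebook and not repeated here; a minimal sketch consistent with the standard DDPM formulation (assuming half of the `dim` dimensions carry sin and the other half cos over a geometric frequency ladder) looks like this:
```python
import math
import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

class SinusoidalPositionEmbeddings(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        half_dim = dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000
        emb = math.log(10000) / (half_dim - 1)
        self.emb = Tensor(np.exp(np.arange(half_dim) * -emb), ms.float32)

    def construct(self, x):
        # x: (batch,) timesteps -> (batch, dim) embeddings
        emb = x[:, None] * self.emb[None, :]
        return ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
```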
```python
downs = nn.CellList([])
ups = nn.CellList([])
ups
```
CellList<>
```python
num_resolutions = len(in_out)
num_resolutions
```
3
```python
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
print(ind,":",is_last)
```
0 : False
1 : False
2 : True
```python
dim_in, dim_out, time_dim  # each timestep is encoded as a 128-dimensional vector
```
(64, 128, 128)
```python
in_out
```
[(20, 32), (32, 64), (64, 128)]
```python
for ind, (dim_in, dim_out) in enumerate(in_out):
print(dim_in, dim_out)
is_last = ind >= (num_resolutions - 1)
downs.append(
nn.CellList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
```
20 32
32 64
64 128
```python
downs
```
CellList<
(0): CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=20, has_bias=True>
>
(ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=20>
(1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
>
(3): Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
>
(1): CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
>
(3): Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
>
(2): CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=128>
>
>
(3): Identity<>
>
>
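Note the `(3): Identity<>` closing the third stage: `is_last` is True there, so no strided convolution is applied and the spatial size is not halved a third time.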
```python
mid_dim = dims[-1]
mid_dim
```
128
```python
mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
```
```python
mid_block1
```
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354b50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17e50>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36db20>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
```python
mid_attn
```
Residual<
(fn): PreNorm<
(fn): Attention<
(to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36d9a0>, bias_init=None, format=NCHW>
(to_out): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac36d880>, bias_init=<mindspore.common.initializer.Uniform object at 0xffffac36d730>, format=NCHW>
>
(norm): GroupNorm<num_groups=1, num_channels=128>
>
>
```python
mid_block2
```
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cee0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c850>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c910>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
```python
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
print(dim_in, dim_out)
is_last = ind >= (num_resolutions - 1)
print(is_last)
```
64 128
False
32 64
False
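Because only `in_out[1:]` is reversed, there are two up stages for the three down stages, and `ind` (0 or 1) never reaches `num_resolutions - 1 = 2`, so `is_last` stays False here and every up stage ends with an actual `Upsample`.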
```python
dim_in
```
32
```python
LinearAttention(dim_in)
```
LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306eb0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac348c40>, bias_init=<mindspore.common.initializer.Uniform object at 0xffffac2bc4c0>, format=NCHW>
(1): LayerNorm<>
>
>
```python
class LinearAttention(nn.Cell):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
self.to_out = nn.SequentialCell(
nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
LayerNorm(dim)
)
self.map = ops.Map()
self.partial = ops.Partial()
def construct(self, x):
b, _, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, 1)
q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
q = ops.softmax(q, -2)
k = ops.softmax(k, -1)
q = q * self.scale
v = v / (h * w)
# 'b h d n, b h e n -> b h d e'
context = ops.bmm(k, v.swapaxes(2, 3))
# 'b h d e, b h d n -> b h e n'
out = ops.bmm(context.swapaxes(2, 3), q)
out = out.reshape((b, -1, h, w))
return self.to_out(out)
```
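Unlike standard attention, which builds an n × n attention map over all n = h·w pixel positions, this linear variant applies softmax to q and k along different axes and first contracts k with v into a small d × e context matrix, so the cost grows linearly with the number of pixels.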
```python
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
ups.append(
nn.CellList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
```
```python
ups
```
CellList<
(0): CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=256, has_bias=True>
>
(ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2b80>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=256>
(1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15430>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15d30>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb24c0>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbcd00>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc550>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c4f0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354580>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cc10>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc30c9a0>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
>
(3): Conv2dTranspose<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cac0>, bias_init=None, format=NCHW>
>
(1): CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c2e0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c5e0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17f10>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17760>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17a00>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17dc0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3064f0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306b50>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc3065b0>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
>
(3): Conv2dTranspose<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306f70>, bias_init=None, format=NCHW>
>
>
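Both up stages start from a doubled input channel count (256 and 128 respectively), matching `dim_out * 2` in the constructor: the forward pass concatenates the saved down-path feature with the upsampled one along the channel axis.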
```python
out_dim = default(out_dim, channels)
out_dim
```
3
```python
final_conv = nn.SequentialCell(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
final_conv
```
SequentialCell<
(0): ConvNextBlock<
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2acb50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17bb0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2100>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(1): Conv2d<input_channels=32, output_channels=3, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2eb0>, bias_init=None, format=NCHW>
>
```python
x
time=5
```
```python
x.shape
```
(2, 3, 32, 32)
```python
init_conv
```
Conv2d<input_channels=3, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2d30>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefedb2be0>, format=NCHW>
```python
x = init_conv(x)
print(x)
```
[[[[ 0.04106301 0.04308875 0.03753628 ... 0.03978373 0.04020362
0.03793241]
[ 0.04131693 0.04225756 0.03624208 ... 0.04158005 0.04384746
0.0374498 ]
[ 0.03251923 0.04382189 0.03682814 ... 0.04343156 0.03790728
0.03667009]
...
[ 0.03987011 0.04261249 0.03721504 ... 0.03754282 0.03530194
0.04190454]
[ 0.03995614 0.04259038 0.04231969 ... 0.03937387 0.03802945
0.03542861]
[ 0.03724413 0.03895703 0.03808391 ... 0.04210365 0.03843816
0.03887339]]
[[-0.07880677 -0.081793 -0.08021648 ... -0.07915598 -0.08803446
-0.07824855]
[-0.07975532 -0.0827108 -0.08153103 ... -0.08920732 -0.08202183
-0.07717112]
[-0.080375 -0.08304221 -0.07943083 ... -0.08371484 -0.07717931
-0.07678773]
...
[-0.07925861 -0.07035945 -0.07607639 ... -0.08380341 -0.08219168
-0.08388805]
[-0.07702561 -0.07861231 -0.08642116 ... -0.08342467 -0.07647635
-0.08471077]
[-0.08218312 -0.08206419 -0.0820056 ... -0.0710914 -0.08050337
-0.08665174]]
[[-0.05770465 -0.05971097 -0.06042907 ... -0.05981689 -0.05351909
-0.06045758]
[-0.06280089 -0.06072729 -0.06125656 ... -0.06167236 -0.05607811
-0.06504007]
[-0.05974955 -0.06224146 -0.05134789 ... -0.06194806 -0.05703649
-0.05661972]
...
[-0.0587245 -0.06006888 -0.06369887 ... -0.0509633 -0.05987025
-0.05689852]
[-0.05888586 -0.06178844 -0.06245932 ... -0.06076533 -0.05802548
-0.06169396]
[-0.05935856 -0.05726556 -0.05836396 ... -0.06468105 -0.05601557
-0.05411654]]
...
[[-0.06885257 -0.06496602 -0.07227325 ... -0.06768468 -0.07973982
-0.06684067]
[-0.06921483 -0.07310341 -0.07145415 ... -0.07373261 -0.06769554
-0.06564213]
[-0.07235637 -0.08390911 -0.06977317 ... -0.06690352 -0.06286541
-0.06959118]
...
[-0.07000594 -0.06508094 -0.06877656 ... -0.07407243 -0.07690564
-0.06396648]
[-0.07082649 -0.07268029 -0.07315704 ... -0.06758922 -0.06662212
-0.06855071]
[-0.07199463 -0.06999994 -0.07014568 ... -0.06523817 -0.07094447
-0.07466151]]
[[ 0.06986891 0.06941634 0.06439675 ... 0.06187101 0.0675493
0.07306495]
[ 0.07319107 0.07266034 0.05997508 ... 0.06689761 0.06815154
0.0660945 ]
[ 0.07065641 0.05923657 0.06411441 ... 0.06652149 0.07088953
0.07194202]
...
[ 0.06848673 0.07591817 0.0726023 ... 0.06602401 0.06890585
0.07259338]
[ 0.07433689 0.06679939 0.06691605 ... 0.0667197 0.07184143
0.06983658]
[ 0.06431621 0.07212089 0.06723586 ... 0.06868842 0.07140361
0.06901537]]
[[ 0.02097934 0.01210438 0.01431934 ... 0.01505992 0.01852277
0.01381299]
[ 0.02296696 0.0177606 0.01976403 ... 0.02147305 0.02210259
0.02313221]
[ 0.02124698 0.02709681 0.02910981 ... 0.01016832 0.02212639
0.01957588]
...
[ 0.02289374 0.01311012 0.01578637 ... 0.01931083 0.01555186
0.0208313 ]
[ 0.01390727 0.02096656 0.01745579 ... 0.01781181 0.02211875
0.01568411]
[ 0.02439262 0.01495296 0.01968778 ... 0.02193322 0.01783368
0.0176824 ]]]
[[[ 0.04223164 0.03378314 0.03601065 ... 0.03959855 0.03485664
0.04071919]
[ 0.0318558 0.05363872 0.03783617 ... 0.04385335 0.0496259
0.03691863]
[ 0.03818982 0.03180957 0.04072122 ... 0.03430039 0.03384047
0.03837577]
...
[ 0.0398777 0.03721025 0.03533046 ... 0.04020133 0.03928016
0.04710523]
[ 0.04118172 0.03496882 0.03100736 ... 0.03642647 0.03914004
0.0371574 ]
[ 0.04037436 0.04040184 0.04165599 ... 0.04403537 0.03254044
0.04335065]]
[[-0.08552387 -0.07319534 -0.08021338 ... -0.07858572 -0.07166487
-0.08406518]
[-0.07923919 -0.08566054 -0.08015955 ... -0.08471547 -0.0847266
-0.08085599]
[-0.08489675 -0.09258271 -0.08831957 ... -0.09042192 -0.08426952
-0.0808774 ]
...
[-0.08036023 -0.07413588 -0.07989521 ... -0.07935498 -0.08571334
-0.08329107]
[-0.07644836 -0.07608277 -0.08767064 ... -0.08434241 -0.08071237
-0.0839122 ]
[-0.07979399 -0.08087463 -0.08673595 ... -0.08414597 -0.08045428
-0.07299927]]
[[-0.06060546 -0.05453672 -0.06102112 ... -0.05194974 -0.0567053
-0.06273571]
[-0.06276039 -0.05693425 -0.04725159 ... -0.06214722 -0.06443968
-0.05762123]
[-0.05252658 -0.06019294 -0.06137866 ... -0.04910715 -0.06131132
-0.06036767]
...
[-0.06173272 -0.05464447 -0.05099018 ... -0.06136036 -0.06400239
-0.06106843]
[-0.05803053 -0.05994222 -0.06404369 ... -0.04949801 -0.05738675
-0.06158596]
[-0.05899998 -0.06198164 -0.05937162 ... -0.06379396 -0.06430338
-0.06287489]]
...
[[-0.07173873 -0.0707745 -0.06975999 ... -0.07155637 -0.06534318
-0.07189398]
[-0.06730746 -0.07013785 -0.06751848 ... -0.07264671 -0.07705939
-0.07342067]
[-0.07058413 -0.07025788 -0.06871852 ... -0.06887744 -0.06563742
-0.07028291]
...
[-0.07809374 -0.06778216 -0.06392691 ... -0.06867532 -0.07118014
-0.07647338]
[-0.07219965 -0.07040192 -0.0732589 ... -0.07633238 -0.0752567
-0.0702922 ]
[-0.06984755 -0.07723872 -0.06846898 ... -0.06786713 -0.06702175
-0.07062964]]
[[ 0.06747851 0.06883495 0.06797507 ... 0.06853593 0.06575806
0.06841848]
[ 0.06514458 0.06994057 0.06866109 ... 0.06339982 0.06309478
0.06588745]
[ 0.06701669 0.0691862 0.06725767 ... 0.06696404 0.07045414
0.07060774]
...
[ 0.07085218 0.0809648 0.06841429 ... 0.06838602 0.06918488
0.07014886]
[ 0.07304276 0.07134987 0.07214254 ... 0.07656243 0.07136226
0.06578355]
[ 0.06968872 0.07193028 0.06518821 ... 0.07004035 0.06891351
0.06959624]]
[[ 0.02295301 0.01347421 0.02212771 ... 0.02214386 0.01323562
0.02334489]
[ 0.01578862 0.01825874 0.01307945 ... 0.0216907 0.02719616
0.02306023]
[ 0.01491401 0.01406 0.02918804 ... 0.02165697 0.01733657
0.01930147]
...
[ 0.01614039 0.019646 0.02148937 ... 0.00664111 0.01888491
0.02413018]
[ 0.01757734 0.01567486 0.01912338 ... 0.02099028 0.01717271
0.01547725]
[ 0.01756483 0.020161 0.01650484 ... 0.01933268 0.0167334
0.01855144]]]]
```python
x.shape
```
(2, 20, 32, 32)
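As expected, `init_conv` lifts the channel count from 3 to 20, while the 7×7 kernel with padding 3 and stride 1 leaves the 32×32 spatial size unchanged.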
```python
# unet_model.init_conv(x)  # calling the method on the instantiated class; this seems to fail
```
```python
import numpy as np
from mindspore import Tensor
# Define the start value, number of steps, and step size for the timesteps (default 1, i.e. each timestep increases by 1)
start = 0
num_steps = 10
step = 1
# Generate a linearly increasing sequence of timesteps
t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32)
print("Linearly increasing timestep sequence:", t)
# time_mlp
# SequentialCell<
# (0): SinusoidalPositionEmbeddings<>
# (1): Dense<input_channels=32, output_channels=128, has_bias=True>
# (2): GELU<>
# (3): Dense<input_channels=128, output_channels=128, has_bias=True>
# >
time_mlp(t)  # this works correctly
```
Linearly increasing timestep sequence: [0 1 2 3 4 5 6 7 8 9]
Tensor(shape=[10, 128], dtype=Float32, value=
[[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02],
[ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02],
[ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02],
...
[ 1.28773689e-01, 1.91335917e-01, -9.48226005e-02 ... 8.54851380e-02, 1.52098373e-01, 2.03581899e-02],
[ 1.22549936e-01, 1.48201510e-01, -8.17623138e-02 ... 4.44053262e-02, 9.75183249e-02, 3.97774130e-02],
[ 1.07752994e-01, 1.08763084e-01, -7.05250949e-02 ... 4.34711799e-02, 6.16942756e-02, 1.67786255e-02]])
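Each of the 10 timesteps maps to a 128-dimensional embedding: the output shape (10, 128) matches `time_dim = dim * 4 = 128`.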
```python
import numpy as np
from mindspore import Tensor
# Define the start value, number of steps, and step size for the timesteps (default 1, i.e. each timestep increases by 1)
start = 0
num_steps = 5
step = 1
# Generate a linearly increasing sequence of timesteps
t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32)
print("Linearly increasing timestep sequence:", t)
# time_mlp
# SequentialCell<
# (0): SinusoidalPositionEmbeddings<>
# (1): Dense<input_channels=32, output_channels=128, has_bias=True>
# (2): GELU<>
# (3): Dense<input_channels=128, output_channels=128, has_bias=True>
# >
time_mlp(t)  # this works correctly
```
Linearly increasing timestep sequence: [0 1 2 3 4]
Tensor(shape=[5, 128], dtype=Float32, value=
[[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02],
[ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02],
[ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02],
[ 1.60673216e-01, 5.66724911e-02, 6.34887069e-02 ... 1.58708930e-01, 1.91956162e-01, -7.78784081e-02],
[ 1.33557051e-01, 1.20526701e-01, 5.27272746e-03 ... 1.78237736e-01, 2.34055102e-01, -9.55123529e-02]])
```python
t.shape
```
(5,)
It looks like you want to construct a tensor `t` of shape `(5,)`, i.e. a one-dimensional vector with 5 elements. With the MindSpore library in Python this is easy to do. Here are several ways to create a tensor of this shape, initialized with specific or random values:
### Initialize with zeros
```python
import numpy as np
import mindspore as ms
# Create an all-zeros tensor of shape (5,)
t = ms.Tensor(np.zeros(5), dtype=ms.float32)
print(t.shape)  # prints: (5,)
```
### Initialize with ones
```python
t = ms.Tensor(np.ones(5), dtype=ms.float32)
print(t.shape)  # prints: (5,)
```
### Initialize as an increasing sequence (e.g., [0, 1, 2, 3, 4])
```python
t = ms.Tensor(np.arange(5), dtype=ms.float32)
print(t.shape)  # prints: (5,)
```
### Initialize with random values
```python
t = ms.Tensor(np.random.rand(5), dtype=ms.float32)
print(t.shape)  # prints: (5,)
```
### Fill with a specific value
```python
value = 3.14  # e.g., use pi as the fill value
t = ms.Tensor(np.full(5, value), dtype=ms.float32)
print(t.shape)  # prints: (5,)
```
Any of the code blocks above creates a tensor `t` of shape `(5,)`; pick the initialization that fits your needs.
```python
h = []
```
```python
downs  # one CellList that holds 3 CellList elements (one per resolution stage)
```
```python
for downsample in downs:
print(downsample)
```
CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=20, has_bias=True>
>
(ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=20>
(1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
>
(3): Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
>
CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
>
(3): Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
>
CellList<
(0): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
>
(1): ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(2): Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=128>
>
>
(3): Identity<>
>
```python
for block1, block2, attn, downsample in downs:
    print("block1:")
    print(block1)
    print("block2:")
    print(block2)
    print("residual attention:")
    print(attn)
    print("downsample:")
    print(downsample)
```
block1:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=20, has_bias=True>
>
(ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=20>
(1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
>
block2:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
residual attention:
Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
>
downsample:
Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
block1:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
>
block2:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
residual attention:
Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
>
downsample:
Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
block1:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
>
block2:
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
residual attention:
Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=128>
>
>
downsample:
Identity<>
```python
dim_in, dim_out, dim
```
(32, 64, 32)
## Because the constructor loops over [(20, 32), (32, 64), (64, 128)], `downs` ends up with 3 `nn.CellList` stages
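As a sanity check, those pairs can be reproduced directly from the constructor's arithmetic; a minimal sketch using this walkthrough's values (dim=32, dim_mults=(1, 2, 4)):
```python
# Reproduce the channel bookkeeping from Unet.__init__ with this notebook's values.
dim = 32
dim_mults = (1, 2, 4)
init_dim = dim // 3 * 2                                # 32 // 3 * 2 = 20
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]  # [20, 32, 64, 128]
in_out = list(zip(dims[:-1], dims[1:]))
print(in_out)  # [(20, 32), (32, 64), (64, 128)] -> three entries in self.downs
```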
```python
convnext_mult
```
2
```python
#dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
dims
```
[20, 32, 64, 128]
```python
i = 0
for block1, block2, attn, downsample in downs:
    i = i + 1
    print("--------", i)
    print("block1:", block1, "block2:", block2, "residual attention:", attn, "downsample:", downsample)
# Each stage was appended in Unet.__init__ as:
#     block_klass(dim_in, dim_out, time_emb_dim=time_dim),   # time_dim = 128
#     block_klass(dim_out, dim_out, time_emb_dim=time_dim),
#     Residual(PreNorm(dim_out, LinearAttention(dim_out))),
#     Downsample(dim_out) if not is_last else nn.Identity(),
block_klass = partial(ConvNextBlock, mult=convnext_mult)  # binds mult=convnext_mult as a keyword argument of ConvNextBlock
block_klass
# class ConvNextBlock(nn.Cell):
#     def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
#         super().__init__()
#         self.mlp = (
#             nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
#             if exists(time_emb_dim)
#             else None
#         )
#         self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
#         self.net = nn.SequentialCell(
#             nn.GroupNorm(1, dim) if norm else nn.Identity(),
#             nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
#             nn.GELU(),
#             nn.GroupNorm(1, dim_out * mult),
#             nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
#         )
#         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
#     def construct(self, x, time_emb=None):
#         h = self.ds_conv(x)
#         if exists(self.mlp) and exists(time_emb):
#             condition = self.mlp(time_emb)
#             condition = condition.expand_dims(-1).expand_dims(-1)
#             h = h + condition
#         h = self.net(h)
#         return h + self.res_conv(x)
### Stage 1 block1
# input_channels=20, output_channels=32 -> (20, 32)
### Stage 1 block2
# Conv2d<input_channels=64 [32*2], output_channels=32>, (res_conv): Identity<> -> (32, 32)
### Stage 1 residual attention
# Conv2d<input_channels=128, output_channels=32> -> (32, 32)
### Stage 1 downsample
# Conv2d<input_channels=32, output_channels=32> -> (32, 32)
### Stage 2 block1
# Conv2d<input_channels=32, output_channels=64> -> (32, 64)
### Stage 2 block2
# Conv2d<input_channels=128 [64*2], output_channels=64>, (res_conv): Identity<> -> (64, 64)
### Stage 2 residual attention
# Conv2d<input_channels=128, output_channels=64> -> (64, 64)
### Stage 2 downsample
# Conv2d<input_channels=64, output_channels=64> -> (64, 64)
### Stage 3 block1
# Conv2d<input_channels=64, output_channels=128> -> (64, 128)
### Stage 3 block2
# Conv2d<input_channels=256 [128*2], output_channels=128>, (res_conv): Identity<> -> (128, 128)
### Stage 3 residual attention
# Conv2d<input_channels=128, output_channels=128>, (res_conv): Identity<> -> (128, 128)
### Stage 3 downsample
# Identity<> -> (128, 128)
```
-------- 1
block1: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=20, has_bias=True>
>
(ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c130>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=20>
(1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc070>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354a90>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3d7cd0>, bias_init=None, format=NCHW>
> block2: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffad3c17f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3062b0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c190>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
> residual attention: Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d90>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc306a60>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
> downsample: Conv2d<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf1f0>, bias_init=None, format=NCHW>
-------- 2
block1: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306fd0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf190>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=32, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf670>, bias_init=None, format=NCHW>
> block2: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf940>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf910>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cf760>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
> residual attention: Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfa30>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfac0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2cfb80>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
> downsample: Conv2d<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfdc0>, bias_init=None, format=NCHW>
-------- 3
block1: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cfe50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2cff40>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac130>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac250>, bias_init=None, format=NCHW>
> block2: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac1f0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac370>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=256>
(4): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac4c0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
> residual attention: Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=128, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac6a0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=128, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2ac7c0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc2ac820>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=128>
>
> downsample: Identity<>
functools.partial(<class '__main__.ConvNextBlock'>, mult=2)
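`functools.partial` pre-binds keyword arguments, so every later call `block_klass(dim_in, dim_out, time_emb_dim=time_dim)` expands to `ConvNextBlock(dim_in, dim_out, time_emb_dim=time_dim, mult=2)`. A tiny self-contained illustration with a plain function (a hypothetical stand-in, not the notebook's class):
```python
from functools import partial

def make_block(dim, dim_out, *, mult=2):
    # stand-in for ConvNextBlock: just report the bound channel sizes
    return (dim, dim_out * mult)

block_klass = partial(make_block, mult=2)
print(block_klass(20, 32))  # (20, 64): mult=2 is already bound
```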
```python
x.shape
```
(2, 20, 32, 32)
```python
t
```
Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4])
```python
t = time_mlp(t)  # now t has the expected (batch, time_dim) shape
t
```
Tensor(shape=[5, 128], dtype=Float32, value=
[[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02],
[ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02],
[ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02],
[ 1.60673216e-01, 5.66724911e-02, 6.34887069e-02 ... 1.58708930e-01, 1.91956162e-01, -7.78784081e-02],
[ 1.33557051e-01, 1.20526701e-01, 5.27272746e-03 ... 1.78237736e-01, 2.34055102e-01, -9.55123529e-02]])
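For reference, a minimal numpy sketch of the sinusoidal timestep embedding that feeds this MLP, assuming `SinusoidalPositionEmbeddings` follows the standard DDPM formulation (half sine, half cosine); the two Dense layers then map the 32-dim embedding to the 128-dim `t` shown above:
```python
import math
import numpy as np

def sinusoidal_embedding(t, dim=32):
    # standard DDPM sinusoidal timestep embedding: half sin, half cos
    half_dim = dim // 2
    freqs = np.exp(np.arange(half_dim) * -(math.log(10000.0) / (half_dim - 1)))
    args = t[:, None].astype(np.float32) * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_embedding(np.arange(5), dim=32)
print(emb.shape)  # (5, 32); Dense(32, 128) -> GELU -> Dense(128, 128) then gives (5, 128)
```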
```python
# take the first row only
new_t = t[0:1, :]  # t[0:1] works too
new_t.shape
```
(1, 128)
```python
t = new_t
```
```python
class ConvNextBlock(nn.Cell):
    def __init__(self, dim=20, dim_out=32, *, time_emb_dim=128, mult=2, norm=True):
        super().__init__()
        # projects the time embedding down to `dim` channels
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # 7x7 depthwise convolution (group=dim)
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )
        # 1x1 convolution for the residual path when channel counts differ
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # broadcast the (batch, dim) embedding over the spatial dimensions
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition
        h = self.net(h)
        return h + self.res_conv(x)
```
```python
BL1=ConvNextBlock()
BL1
```
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=20, has_bias=True>
>
(ds_conv): Conv2d<input_channels=20, output_channels=20, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=20, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2b80>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=20>
(1): Conv2d<input_channels=20, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2a00>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffddded2f70>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=20, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdde372970>, bias_init=None, format=NCHW>
>
```python
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
x = init_conv(x)
t = new_t
x.shape, t.shape
```
((2, 20, 32, 32), (1, 128))
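Note that the (1, 128) time embedding still conditions both images in the batch: inside `ConvNextBlock` the projected embedding is expanded to (1, dim, 1, 1) and broadcast against `h` of shape (2, dim, H, W). A minimal numpy sketch of that broadcast, using the shapes from this cell:
```python
import numpy as np

h = np.zeros((2, 20, 32, 32), dtype=np.float32)  # output of ds_conv
condition = np.zeros((1, 20), dtype=np.float32)  # mlp(time_emb) with batch 1
condition = condition[..., None, None]           # expand_dims(-1) twice -> (1, 20, 1, 1)
print((h + condition).shape)                     # (2, 20, 32, 32): broadcast over batch and space
```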
```python
BL1(x, t)
```
Tensor(shape=[2, 32, 32, 32], dtype=Float32, value=
[[[[ 4.79678512e-01, 6.18093789e-01, 4.11911160e-01 ... 2.45554969e-01, 2.87007272e-01, 5.84303178e-02],
[ 3.35714668e-01, 5.81144929e-01, 3.02089810e-01 ... 7.34893382e-02, -2.16056317e-01, -2.40196183e-01],
[ 6.84010506e-01, 1.10967433e+00, 7.40820885e-01 ... 4.42677915e-01, -5.60586452e-02, -1.65627971e-01],
...
[ 7.63499856e-01, 1.13718486e+00, 9.87868309e-01 ... 6.83883190e-01, 1.12375900e-01, -3.52420285e-02],
[ 7.06175983e-01, 1.14337587e+00, 9.23162937e-01 ... 6.31385088e-01, 1.49176538e-01, -3.98113094e-02],
[ 1.76896825e-01, 4.82230633e-01, 4.89957243e-01 ... 4.00457889e-01, -2.55417023e-02, 1.21514685e-01]],
[[ 2.81717628e-01, -9.01857391e-04, -8.76490697e-02 ... -7.88121223e-02, -1.26668364e-01, -1.73759460e-01],
[ 1.92831263e-01, -1.43926665e-01, -1.48099199e-01 ... -2.68793881e-01, -1.00307748e-01, -1.20211102e-01],
[ 1.39493421e-01, 2.17211112e-01, 1.45897210e-01 ... 1.56540543e-01, 1.65525198e-01, 5.83062395e-02],
...
[ 6.92536831e-02, 5.68503514e-02, 3.68858390e-02 ... 8.93135369e-02, 1.07637540e-01, 4.47027944e-02],
[ 1.12979397e-01, 2.12710798e-01, 5.37276417e-02 ... 1.01731792e-01, -4.49074507e-02, -2.13617496e-02],
[ 7.63944685e-02, 1.35763273e-01, 1.16834700e-01 ... 1.85339376e-01, 1.34029865e-01, 2.19782546e-01]],
[[ 1.30130202e-01, 3.29991281e-01, 4.25871283e-01 ... 3.08504313e-01, 4.61269379e-01, 1.57225683e-01],
[ 4.05929983e-01, 7.34413862e-01, 8.77515614e-01 ... 7.71579146e-01, 9.44401443e-01, 5.34572124e-01],
[-3.62766758e-02, 3.63330275e-01, 4.08758491e-01 ... 3.54813248e-01, 5.34208298e-01, 2.87856668e-01],
...
[-1.26003683e-01, 2.56464094e-01, 3.78679991e-01 ... 4.59082156e-01, 5.85478425e-01, 2.87693620e-01],
[-4.88909632e-02, 2.99566031e-01, 3.99350137e-01 ... 4.61422205e-01, 4.17674005e-01, 7.26606846e-02],
[-2.39267662e-01, -3.09484214e-01, -2.71493912e-01 ... -7.24366307e-02, -1.12498447e-01, -1.38472736e-01]],
...
[[-1.59212112e-01, 7.95098245e-02, -1.85586754e-02 ... -2.23550811e-01, -2.70033002e-01, -2.44036630e-01],
[-1.60980105e-01, 4.35432374e-01, 5.99099815e-01 ... 4.49325353e-01, 4.23938036e-01, 3.12254220e-01],
[-3.05827290e-01, 7.60348588e-02, 2.39996284e-01 ... 2.72248350e-02, -1.65684037e-02, -1.06293596e-01],
...
[-2.98552692e-01, 5.39370701e-02, 2.43080735e-01 ... 1.28992423e-01, 6.57526404e-02, -7.48077184e-02],
[-2.99177408e-01, -1.83835357e-01, -3.95203456e-02 ... -4.43453342e-02, -1.39597371e-01, -2.18513533e-01],
[ 1.12659251e-02, 2.71181725e-02, 8.25900063e-02 ... 1.92151055e-01, 2.09751755e-01, 4.28373404e-02]],
[[-1.12641305e-01, 2.13623658e-01, 7.53313601e-02 ... -1.21324155e-02, -1.53158829e-01, -4.77597833e-01],
[-3.84328097e-01, 6.15204051e-02, 1.25263743e-02 ... -7.10279793e-02, -2.77535737e-01, -3.76104653e-01],
[-4.07613993e-01, 1.12187594e-01, 3.72527242e-02 ... -2.06387043e-02, -1.46990225e-01, -2.87585199e-01],
...
[-3.61820847e-01, 9.99522805e-02, -9.61808674e-03 ... -4.89163473e-02, -1.65467933e-01, -3.17837149e-01],
[-6.31257832e-01, -2.93515027e-01, -3.12220454e-01 ... -1.29600003e-01, -1.40924498e-01, -1.52635127e-01],
[-5.44437885e-01, -2.92856932e-01, -2.64693975e-01 ... -1.66876107e-01, -1.02364674e-01, 1.51942633e-02]],
[[-4.31133687e-01, -6.47886336e-01, -8.31129193e-01 ... -8.14953923e-01, -8.53494108e-01, -5.33654928e-01],
[-1.00464332e+00, -1.09428477e+00, -1.38921976e+00 ... -1.53568864e+00, -1.31930661e+00, -4.56116676e-01],
[-1.13990593e+00, -1.03481460e+00, -1.49022770e+00 ... -1.62232399e+00, -1.40410137e+00, -5.24185598e-01],
...
[-1.26637757e+00, -1.16348124e+00, -1.51403701e+00 ... -1.62057757e+00, -1.49686539e+00, -5.62209725e-01],
[-1.11411476e+00, -9.28978443e-01, -1.17068827e+00 ... -1.24776149e+00, -1.03560197e+00, -3.00330132e-01],
[-8.52251232e-01, -7.45468676e-01, -9.34467912e-01 ... -1.00492966e+00, -8.28293085e-01, -2.36020580e-01]]],
[[[ 4.66426373e-01, 6.25464499e-01, 3.93321365e-01 ... 2.30096430e-01, 3.03178400e-01, 5.51086180e-02],
[ 3.32810491e-01, 6.17641032e-01, 3.06311995e-01 ... 1.02875866e-01, -1.90033853e-01, -2.67078996e-01],
[ 6.84333742e-01, 1.09874713e+00, 7.69674480e-01 ... 4.62564647e-01, -6.67065307e-02, -2.36097589e-01],
...
[ 7.33127177e-01, 1.17721725e+00, 9.89053011e-01 ... 7.15648472e-01, 1.17136240e-01, -4.71793935e-02],
[ 6.87817514e-01, 1.09633350e+00, 8.85757685e-01 ... 5.86604059e-01, 9.45525989e-02, -3.36224921e-02],
[ 1.95047542e-01, 4.68445808e-01, 4.64000225e-01 ... 3.75145793e-01, 1.80484354e-03, 1.13696203e-01]],
[[ 2.65101492e-01, 2.46687196e-02, -1.07584536e-01 ... -1.03970490e-01, -8.17846432e-02, -1.53097644e-01],
[ 1.52385280e-01, -8.83764774e-02, -1.62100300e-01 ... -2.18001902e-01, -9.41001922e-02, -1.19305871e-01],
[ 1.61403462e-01, 2.30408147e-01, 1.57331020e-01 ... 1.96940184e-01, 1.30461589e-01, 6.52605519e-02],
...
[ 5.33300415e-02, 1.22396260e-01, -1.88096687e-02 ... 1.05915010e-01, 1.53571054e-01, 2.45359484e-02],
[ 9.76279452e-02, 1.82655305e-01, 1.09691672e-01 ... 1.40925452e-01, 8.01324844e-04, 1.88996084e-03],
[ 7.78941736e-02, 1.48156703e-01, 1.39126211e-01 ... 2.34847367e-01, 1.08238310e-01, 2.09336147e-01]],
[[ 1.17656320e-01, 3.43433738e-01, 4.39827234e-01 ... 3.09850901e-01, 4.53984410e-01, 1.49862975e-01],
[ 3.95844698e-01, 7.22729802e-01, 8.56524229e-01 ... 8.03788304e-01, 9.29986835e-01, 5.35356879e-01],
[-5.63383885e-02, 3.74548256e-01, 3.96855712e-01 ... 3.82491171e-01, 5.69522381e-01, 3.37262630e-01],
...
[-1.52335018e-01, 2.06072912e-01, 3.62504959e-01 ... 4.47174758e-01, 5.80285132e-01, 2.62391001e-01],
[-3.93561423e-02, 3.46813500e-01, 3.57039988e-01 ... 4.35864031e-01, 4.56840277e-01, 7.46745914e-02],
[-2.53383636e-01, -2.85494059e-01, -2.50772089e-01 ... -1.19937055e-01, -9.26215500e-02, -1.42144099e-01]],
...
[[-1.72060549e-01, 7.35142902e-02, 1.05164722e-02 ... -2.21164092e-01, -2.66059518e-01, -2.46203467e-01],
[-1.51937515e-01, 4.78927851e-01, 5.69894075e-01 ... 4.43681806e-01, 4.60492224e-01, 2.69292653e-01],
[-3.13104421e-01, 1.40309155e-01, 2.25837916e-01 ... 3.81425694e-02, 7.96409026e-02, -1.00592285e-01],
...
[-3.12582165e-01, 3.55643854e-02, 1.99504092e-01 ... 1.73697099e-01, 5.84969185e-02, -7.21051544e-02],
[-2.97579527e-01, -1.40844122e-01, -5.89616746e-02 ... -2.86213737e-02, -1.22039340e-01, -2.13227138e-01],
[ 3.26608960e-03, 4.80151782e-03, 5.54511398e-02 ... 1.92409024e-01, 1.99357480e-01, 3.34331095e-02]],
[[-1.12870112e-01, 2.09549189e-01, 9.65655223e-02 ... -4.27889600e-02, -1.44986391e-01, -4.36559677e-01],
[-3.58288437e-01, 9.11962241e-02, -8.71371478e-04 ... -4.59572896e-02, -2.30747938e-01, -4.01585191e-01],
[-3.93885374e-01, 1.43090501e-01, -1.07427724e-02 ... -2.74238884e-02, -1.51127338e-01, -3.17271531e-01],
...
[-3.39975446e-01, 6.32377267e-02, -4.78150323e-02 ... -9.10668075e-02, -1.39780402e-01, -2.90815294e-01],
[-6.53750360e-01, -2.34141201e-01, -2.83103675e-01 ... -9.91634205e-02, -1.61574319e-01, -1.63588241e-01],
[-5.44618607e-01, -2.84289837e-01, -2.75803030e-01 ... -1.71445966e-01, -1.29518926e-01, 9.64298844e-04]],
[[-4.26712006e-01, -6.10979497e-01, -8.42772007e-01 ... -8.18627119e-01, -8.19367886e-01, -5.47874212e-01],
[-1.00363958e+00, -1.09864676e+00, -1.45736146e+00 ... -1.57554007e+00, -1.27784061e+00, -5.02054691e-01],
[-1.13520634e+00, -1.06120992e+00, -1.46519005e+00 ... -1.58347833e+00, -1.37640107e+00, -5.30683100e-01],
...
[-1.25580835e+00, -1.18095815e+00, -1.52926147e+00 ... -1.63596940e+00, -1.44870484e+00, -5.40451765e-01],
[-1.12258959e+00, -9.22193348e-01, -1.14803100e+00 ... -1.24472821e+00, -1.02205288e+00, -3.08310509e-01],
[-8.60808074e-01, -7.58719683e-01, -9.61336493e-01 ... -9.92724419e-01, -8.70862603e-01, -2.35298872e-01]]]])
```python
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal())
x.shape  # show the data shape
x = init_conv(x)
t = new_t
x.shape, t.shape  # these steps must run in the same cell for the loop below to work
h = []  # stores the per-stage features for the skip connections
for block1, block2, attn, downsample in downs:
    # loops 3 times; each iteration unpacks the four cells of one stage
    x = block1(x, t)
    x = block2(x, t)
    x = attn(x)
    h.append(x)
    x = downsample(x)
```
```python
x.shape
```
(2, 128, 8, 8)
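The spatial shrinkage checks out: each downsampling convolution here has kernel 4, stride 2, padding 1, which halves H and W, so 32 -> 16 -> 8 over the two real downsample ops (the third stage uses Identity). A quick sanity check with the standard convolution output-size formula:
```python
def conv_out(size, kernel=4, stride=2, padding=1):
    # standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

s = 32
for _ in range(2):  # only the first two stages actually downsample
    s = conv_out(s)
    print(s)  # 16, then 8
```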
```python
len(h)
```
3
```python
h  # [x0: (2, 32, 32, 32), x1: (2, 64, 16, 16), x2: (2, 128, 8, 8)]
```
[Tensor(shape=[2, 32, 32, 32], dtype=Float32, value=
Tensor(shape=[2, 64, 16, 16], dtype=Float32, value=
Tensor(shape=[2, 128, 8, 8], dtype=Float32, value=
```python
len_h = len(h) - 1
len_h
```
2
```python
h[len_h].shape, x.shape  # the last downsample is Identity, so the shape is unchanged
```
((2, 128, 8, 8), (2, 128, 8, 8))
```python
ops.concat((x, h[len_h]), 1)
```
Tensor(shape=[2, 256, 8, 8], dtype=Float32,
This snippet uses MindSpore's `ops.concat` function to perform tensor concatenation. A detailed breakdown:
- `ops.concat`: a function in MindSpore's operations library that concatenates a list of tensors along a given axis ("ops" is short for operations, the entry point to MindSpore's math and array primitives).
- `(x, h[len_h])`: a tuple holding the two tensors to concatenate. `x` is the current feature map; `h[len_h]` is the element of `h` at index `len_h`. Since `len_h` starts at `len(h) - 1`, this first fetches the most recently stored encoder feature, and `len_h` is decremented on each iteration to walk back up the stack.
- `1`: the axis along which to concatenate. Dimensions are 0-indexed, so `1` is the second dimension, i.e. the channel axis in NCHW layout; `x` and `h[len_h]` are stacked channel-wise into a new tensor.
In short, this concatenates `x` with the stored encoder feature `h[len_h]` along the channel dimension; in this U-Net it is exactly the skip connection that carries encoder detail into the decoder.
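A minimal sketch confirming the shape arithmetic of this skip connection (assuming only that mindspore and numpy are installed):
```python
import numpy as np
import mindspore as ms
from mindspore import ops

x = ms.Tensor(np.zeros((2, 128, 8, 8), dtype=np.float32))     # decoder feature
skip = ms.Tensor(np.zeros((2, 128, 8, 8), dtype=np.float32))  # stored encoder feature h[len_h]
y = ops.concat((x, skip), 1)  # channel counts add up: 128 + 128 = 256
print(y.shape)  # (2, 256, 8, 8)
```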
```python
x = mid_block1(x, t)
x = mid_attn(x)
x = mid_block2(x, t)
x.shape
```
(2, 128, 8, 8)
```python
len_h = len(h) - 1
len_h
```
2
```python
i = 0
for block1, block2, attn, upsample in ups:  # ups has only 2 stages
    i = i + 1
    print("--------", i)
    print("block1:", block1, "block2:", block2, "residual attention:", attn, "upsample:", upsample)
# Each stage was appended in Unet.__init__ as:
#     block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
#     block_klass(dim_in, dim_in, time_emb_dim=time_dim),
#     Residual(PreNorm(dim_in, LinearAttention(dim_in))),
#     Upsample(dim_in) if not is_last else nn.Identity(),
### Stage 1 block1
# (res_conv): Conv2d<input_channels=256, output_channels=64> -> (128, 64)
### Stage 1 block2
# (4): Conv2d<input_channels=128 [64*2], output_channels=64>, (res_conv): Identity<> -> (128, 64)
### Stage 1 residual attention
# Conv2d<input_channels=128, output_channels=64> -> (128, 64)
### Stage 1 upsample
# Conv2dTranspose<input_channels=64, output_channels=64> -> (64, 64)
### Stage 2 block1
# (res_conv): Conv2d<input_channels=128, output_channels=32> -> (64, 32)
### Stage 2 block2
# Conv2d<input_channels=64 [32*2], output_channels=32>, (res_conv): Identity<> -> (64, 32)
### Stage 2 residual attention
# Conv2d<input_channels=128, output_channels=32> -> (64, 32)
### Stage 2 upsample
# Conv2dTranspose<input_channels=32, output_channels=32> -> (32, 32)
```
-------- 1
block1: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=256, has_bias=True>
>
(ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2b80>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=256>
(1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15430>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee15d30>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb24c0>, bias_init=None, format=NCHW>
> block2: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=64, has_bias=True>
>
(ds_conv): Conv2d<input_channels=64, output_channels=64, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=64, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbcd00>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=64>
(1): Conv2d<input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedbc550>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c4f0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
> residual attention: Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=64, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xffffac354580>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cc10>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc30c9a0>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=64>
>
> upsample: Conv2dTranspose<input_channels=64, output_channels=64, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30cac0>, bias_init=None, format=NCHW>
-------- 2
block1: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=128, has_bias=True>
>
(ds_conv): Conv2d<input_channels=128, output_channels=128, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=128, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c2e0>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=128>
(1): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc30c5e0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17f10>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17760>, bias_init=None, format=NCHW>
> block2: ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=32, has_bias=True>
>
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17a00>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17580>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17dc0>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
> residual attention: Residual<
(fn): PreNorm<
(fn): LinearAttention<
(to_qkv): Conv2d<input_channels=32, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc3064f0>, bias_init=None, format=NCHW>
(to_out): SequentialCell<
(0): Conv2d<input_channels=128, output_channels=32, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306b50>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffefc3065b0>, format=NCHW>
(1): LayerNorm<>
>
>
(norm): GroupNorm<num_groups=1, num_channels=32>
>
> upsample: Conv2dTranspose<input_channels=32, output_channels=32, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc306f70>, bias_init=None, format=NCHW>
```python
class ConvNextBlock(nn.Cell):
    def __init__(self, dim=256, dim_out=64, *, time_emb_dim=128, mult=2, norm=True):
        super().__init__()
        # projects the time embedding down to `dim` channels
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # 7x7 depthwise convolution (group=dim)
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )
        # 1x1 convolution for the residual path when channel counts differ
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # broadcast the (batch, dim) embedding over the spatial dimensions
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition
        h = self.net(h)
        return h + self.res_conv(x)
```
```python
BL1 = ConvNextBlock()
BL1
```
ConvNextBlock<
(mlp): SequentialCell<
(0): GELU<>
(1): Dense<input_channels=128, output_channels=256, has_bias=True>
>
(ds_conv): Conv2d<input_channels=256, output_channels=256, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=256, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdddd94880>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=256>
(1): Conv2d<input_channels=256, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdb504af10>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=128>
(4): Conv2d<input_channels=128, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdb53c3d60>, bias_init=None, format=NCHW>
>
(res_conv): Conv2d<input_channels=256, output_channels=64, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffdcc3e9970>, bias_init=None, format=NCHW>
>
```python
for block1, block2, attn, upsample in ups:
x = ops.concat((x, h[len_h]), 1)
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x.shape
```
(2, 32, 32, 32)
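The decoder mirrors the encoder's shape arithmetic: each `Conv2dTranspose` with kernel 4, stride 2, padding 1 doubles H and W, so 8 -> 16 -> 32 across the two up stages. A matching sanity check with the standard transposed-convolution output-size formula:
```python
def deconv_out(size, kernel=4, stride=2, padding=1):
    # standard transposed-convolution output-size formula
    return (size - 1) * stride - 2 * padding + kernel

s = 8
for _ in range(2):  # both up stages upsample
    s = deconv_out(s)
    print(s)  # 16, then 32
```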
```python
rx = final_conv(x)
rx.shape
```
(2, 3, 32, 32)
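Putting the pieces together, a minimal end-to-end sketch (assuming the `Unet` class and helpers defined earlier in this walkthrough are in scope):
```python
import numpy as np
import mindspore as ms

unet_model = Unet(dim=32, channels=3, dim_mults=(1, 2, 4))
x = ms.Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))  # a batch of noisy images
t = ms.Tensor(np.array([0, 10], dtype=np.int32))                 # one timestep per image
noise_pred = unet_model(x, t)
print(noise_pred.shape)  # (2, 3, 32, 32): same shape as the input, as epsilon-prediction requires
```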
```python
def construct(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
len_h = len(h) - 1
    for block1, block2, attn, upsample in self.ups:  # ups has 2 stages while downs has 3, so this loop runs twice and never consumes h[0], the first (20, 32) feature
        x = ops.concat((x, h[len_h]), 1)  # the skip connection: concatenate the stored encoder feature along the channel axis
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
```
This code defines a U-Net-style model, typically used in image processing, generation, or segmentation tasks where fine detail must be preserved while context is captured. It uses an encoder-decoder structure combined with skip connections. A part-by-part breakdown:
### Initialization and time embedding
- `self.init_conv(x)`: applies the initial convolution to the input `x` to begin feature extraction.
- `t = self.time_mlp(time)`: if the model includes time conditioning (common for sequential data, or for injecting the diffusion timestep in generative models), the time signal `time` is passed through an MLP to produce the time embedding `t`.
### Encoder path (downsampling)
- Loop over the modules in `self.downs` (each containing two convolution blocks `block1, block2`, an attention module `attn`, and a downsampling module `downsample`):
- The two convolution blocks transform the features, optionally conditioned on the time embedding `t`.
- The attention module `attn` enriches the feature representation.
- The current feature map `x` is appended to the list `h` for later use as a skip connection.
- The downsampling module reduces the spatial size while increasing depth.
### Bottleneck
- A set of middle layers, two convolution blocks interleaved with an attention module, further refines the features.
### Decoder path (upsampling)
- Loop over the modules in `self.ups`, mirroring the encoder but with upsampling:
- `ops.concat((x, h[len_h]), 1)`: the key skip-connection step; the decoder output `x` is concatenated with the corresponding encoder feature map `h[len_h]` along the channel dimension (axis 1), passing local detail forward.
- `len_h` is decremented so the next iteration fetches the feature map from the stage above.
- The two convolution blocks, the attention module, and the upsampling op are then applied, gradually restoring spatial size while integrating information.
### Output
- Finally, `self.final_conv(x)` applies the last convolution to produce the output feature map or the pixel-level prediction.
### What the skip connections do
- Skip connections (here realized as channel-wise concatenation of feature maps) mitigate vanishing gradients and let the network learn detail more effectively.
- They allow low-level features (which retain more detail) to merge with high-level features (which carry more context) during decoding, which is crucial for recovering fine structure, especially in image generation and segmentation.
- This design therefore helps reconstruct input detail faithfully and improves the fidelity and sharpness of the generated content.
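For contrast with the concatenation-style skip above, the `Residual` wrapper around the attention blocks is an additive skip. A minimal sketch of that wrapper, assuming it matches the interface printed throughout this walkthrough:
```python
import mindspore.nn as nn

class Residual(nn.Cell):
    """Additive skip connection: y = fn(x) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
```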
```python
final_conv  # maps the features back to 3 channels
```
SequentialCell<
(0): ConvNextBlock<
(ds_conv): Conv2d<input_channels=32, output_channels=32, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=32, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefc2acb50>, bias_init=None, format=NCHW>
(net): SequentialCell<
(0): GroupNorm<num_groups=1, num_channels=32>
(1): Conv2d<input_channels=32, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefee17bb0>, bias_init=None, format=NCHW>
(2): GELU<>
(3): GroupNorm<num_groups=1, num_channels=64>
(4): Conv2d<input_channels=64, output_channels=32, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2100>, bias_init=None, format=NCHW>
>
(res_conv): Identity<>
>
(1): Conv2d<input_channels=32, output_channels=3, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffefedb2eb0>, bias_init=None, format=NCHW>
>