赞
踩
扩散模型并没有那么复杂:与其他生成模型一样,它们都将来自某个简单分布的噪声转换为数据样本。Diffusion也是从纯噪声开始,通过一个神经网络学习逐步去噪,最终得到一张实际图像。
def rearrange(head, inputs):
    """Split the second axis of a 4-D input into ``head`` groups.

    Args:
        head: Number of groups (attention heads) to split the second axis into.
        inputs: Object exposing a 4-element ``shape`` ``(b, hc, x, y)`` and a
            ``reshape`` method.

    Returns:
        ``inputs`` reshaped to ``(b, head, hc // head, x * y)``; the two
        trailing axes are flattened into one.
    """
    b, hc, x, y = inputs.shape
    c = hc // head
    return inputs.reshape((b, head, c, x * y))
def rsqrt(x):
    """Return the reciprocal of the square root of ``x``.

    Computed as ``ops.inv(ops.sqrt(x))``.
    """
    res = ops.sqrt(x)
    return ops.inv(res)
def randn_like(x, dtype=None):
    """Draw standard-normal samples with the same shape as ``x``.

    Args:
        x: Reference tensor; its ``shape`` (and, by default, ``dtype``) is used.
        dtype: Target dtype of the result. Defaults to ``x.dtype``.

    Returns:
        The output of ``ops.standard_normal(x.shape)`` cast to ``dtype``.
    """
    if dtype is None:
        dtype = x.dtype
    return ops.standard_normal(x.shape).astype(dtype)
def randn(shape, dtype=None):
    """Draw standard-normal samples of the given ``shape``.

    Args:
        shape: Shape passed to ``ops.standard_normal``.
        dtype: Target dtype of the result. Defaults to ``ms.float32``.

    Returns:
        The output of ``ops.standard_normal(shape)`` cast to ``dtype``.
    """
    if dtype is None:
        dtype = ms.float32
    return ops.standard_normal(shape).astype(dtype)
def randint(low, high, size, dtype=ms.int32):
    """Sample a tensor of shape ``size`` via ``ops.uniform``.

    Args:
        low: Lower bound; wrapped in a ``Tensor`` of ``dtype``.
        high: Upper bound; wrapped in a ``Tensor`` of ``dtype``.
        size: Shape of the sampled tensor.
        dtype: dtype used for the bounds and requested from ``ops.uniform``.

    Returns:
        The tensor returned by ``ops.uniform``.
    """
    # NOTE(review): whether `high` is inclusive is decided by ops.uniform —
    # confirm against the MindSpore documentation.
    return ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
def exists(x):
    """Return ``True`` when ``x`` is not ``None`` (falsy values still count)."""
    return x is not None
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    Args:
        val: Preferred value; used whenever it is not ``None``.
        d: Fallback. If callable it is invoked with no arguments and its
            result is returned; otherwise it is returned as-is.

    Returns:
        ``val``, ``d()`` or ``d`` as described above.
    """
    if val is not None:
        return val
    return d() if callable(d) else d
def _check_dtype(d1, d2):
    """Pick a common dtype for two operands.

    Args:
        d1: First dtype.
        d2: Second dtype.

    Returns:
        ``ms.float32`` if either argument is ``ms.float32``; otherwise the
        shared dtype when both arguments are equal.

    Raises:
        ValueError: If the dtypes differ and neither is ``ms.float32``.
    """
    if ms.float32 in (d1, d2):
        return ms.float32
    if d1 == d2:
        return d1
    raise ValueError('dtype is not supported.')
class Residual(nn.Cell):
    """Wrap a callable and add its input back onto its output."""

    def __init__(self, fn):
        """Store the wrapped callable ``fn``."""
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        """Return ``fn(x, *args, **kwargs) + x``."""
        return self.fn(x, *args, **kwargs) + x
这些是辅助的方法
def Upsample(dim):
    """Build an ``nn.Conv2dTranspose`` layer with ``dim`` in and out channels.

    The layer is created with positional arguments ``4, 2`` (kernel size and
    stride per the MindSpore signature), ``pad_mode="pad"`` and ``padding=1``.
    """
    return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)
def Downsample(dim):
    """Build an ``nn.Conv2d`` layer with ``dim`` in and out channels.

    The layer is created with positional arguments ``4, 2`` (kernel size and
    stride per the MindSpore signature), ``pad_mode="pad"`` and ``padding=1``.
    """
    return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
上面是上采样和下采样模块。
由于网络需要知道当前处于哪个噪声水平(即时间步),这里用正弦(sin/cos)位置编码来表示时间步。
class SinusoidalPositionEmbeddings(nn.Cell):
    """Map scalar positions to sine/cosine embeddings."""

    def __init__(self, dim):
        """Precompute the frequency table.

        Args:
            dim: Embedding size; ``dim // 2`` frequencies are generated.
                NOTE: ``dim`` of 2 or 3 makes ``half_dim - 1`` zero and the
                division below raises ``ZeroDivisionError``.
        """
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        """Return ``concat(sin(x * f), cos(x * f))`` along the last axis.

        ``x`` is indexed as ``x[:, None]``, i.e. it is treated as a batch of
        scalars; the result has ``2 * (dim // 2)`` entries per element.
        """
        emb = x[:, None] * self.emb[None, :]
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb
class Block(nn.Cell):
    """Convolution -> group norm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=1):
        """Create the layers.

        Args:
            dim: Number of input channels of the convolution.
            dim_out: Number of output channels.
            groups: Number of groups passed to ``nn.GroupNorm``.
        """
        super().__init__()
        # The original also re-assigned self.proj through an undefined name
        # `c` with the same arguments; that duplicate line has been dropped.
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        """Apply the block.

        Args:
            x: Input passed to the convolution.
            scale_shift: Optional ``(scale, shift)`` pair; when given the
                normalised activations become ``x * (scale + 1) + shift``.
        """
        x = self.proj(x)
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x
class ConvNextBlock(nn.Cell):
    """Grouped 7x7 convolution followed by a two-convolution stack, with a
    residual connection and optional time-embedding conditioning."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        """Create the layers.

        Args:
            dim: Number of input channels.
            dim_out: Number of output channels.
            time_emb_dim: Size of the time embedding; when ``None`` no
                conditioning MLP is created.
            mult: Width multiplier of the hidden convolution
                (``dim_out * mult`` channels).
            norm: Whether the stack starts with ``nn.GroupNorm(1, dim)``
                (otherwise ``nn.Identity``).
        """
        super().__init__()
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # group=dim: one filter group per input channel.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )
        # A 1x1 convolution matches channel counts for the residual branch.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        """Apply the block.

        Args:
            x: Input feature map.
            time_emb: Optional time embedding; used only when the
                conditioning MLP exists.

        Returns:
            ``net(ds_conv(x) [+ mlp(time_emb)]) + res_conv(x)``.
        """
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # The original asserted exists(time_emb) here; the guard above
            # already guarantees it, so the assert could never fire.
            condition = self.mlp(time_emb)
            # Append two trailing axes so the condition broadcasts over h.
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition
        h = self.net(h)
        return h + self.res_conv(x)
这有点深奥了,不过本质上就是在构建U-Net所需的基本模块。
unet是图像的编解码器,可以捕捉细节
class Attention(nn.Cell):
    """Multi-head softmax attention over the flattened spatial positions."""

    def __init__(self, dim, heads=4, dim_head=32):
        """Create the projections.

        Args:
            dim: Number of input and output channels.
            heads: Number of attention heads.
            dim_head: Channels per head; queries are scaled by
                ``dim_head ** -0.5``.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        """Attend over the positions of ``x`` (unpacked as ``b, _, h, w``)."""
        b, _, h, w = x.shape
        # Split the fused projection into three chunks along axis 1.
        qkv = self.to_qkv(x).chunk(3, 1)
        # Reshape each chunk to (b, heads, c, h * w) via rearrange.
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
        return self.to_out(out)
class LayerNorm(nn.Cell):
    """Normalise over axis 1 with a learnable per-channel gain."""

    def __init__(self, dim):
        """Create the gain ``g`` of shape ``(1, dim, 1, 1)``, initialised to ones."""
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        """Return ``(x - mean) * rsqrt(var + eps) * g`` with statistics taken
        over axis 1 and ``eps = 1e-5``."""
        eps = 1e-5
        # NOTE(review): `keepdims` vs `keep_dims` follows the source; confirm
        # both spellings against the installed MindSpore version.
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g
class LinearAttention(nn.Cell):
    """Multi-head attention that builds a per-head context matrix from keys
    and values first, then applies it to the queries."""

    def __init__(self, dim, heads=4, dim_head=32):
        """Create the projections.

        Args:
            dim: Number of input and output channels.
            heads: Number of attention heads.
            dim_head: Channels per head; queries are scaled by
                ``dim_head ** -0.5``.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim),
        )
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        """Attend over the positions of ``x`` (unpacked as ``b, _, h, w``)."""
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        # Reshape each chunk to (b, heads, c, h * w) via rearrange.
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        # Queries are normalised over axis -2, keys over the last axis.
        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        out = out.reshape((b, -1, h, w))
        return self.to_out(out)


# Author's note: the classes above are the attention modules ("the network's
# weights, I suppose").


class PreNorm(nn.Cell):
    """Apply ``nn.GroupNorm(1, dim)`` before calling the wrapped ``fn``."""

    def __init__(self, dim, fn):
        """Store ``fn`` and create the single-group normalisation layer."""
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        """Return ``fn(norm(x))``."""
        x = self.norm(x)
        return self.fn(x)


# Author's note: this interleaves the U-Net's convolution/attention layers
# with group normalisation.
class Unet(nn.Cell):
    """U-shaped network: a down path, a middle stage and an up path with
    skip connections, optionally conditioned on a time embedding."""

    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        """Build all stages.

        Args:
            dim: Base channel width.
            init_dim: Channels after the first convolution; defaults to
                ``dim // 3 * 2``.
            out_dim: Output channels; defaults to ``channels``.
            dim_mults: Multipliers of ``dim`` giving the width of each level.
            channels: Number of input channels.
            with_time_emb: Whether to build the time-embedding MLP.
            convnext_mult: ``mult`` forwarded to every ``ConvNextBlock``.
        """
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        dims = [init_dim, *(dim * m for m in dim_mults)]
        # Consecutive (input, output) channel pairs, one per level.
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # NOTE: this loop runs over num_resolutions - 1 pairs, so `ind` never
        # reaches num_resolutions - 1 and is_last stays False here; every up
        # stage therefore ends with Upsample.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        # dim_out * 2: the skip tensor is concatenated on axis 1.
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        """Run the network.

        Args:
            x: Input passed to the first convolution.
            time: Input of the time MLP; ignored when the MLP was not built.

        Returns:
            The output of ``final_conv``.
        """
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)  # saved before downsampling, reused as a skip input
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Walk the saved activations from the last one backwards.
        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)


# Author's note: in short, this just assembles all the parts into one large
# network.

# Forward diffusion


def linear_beta_schedule(timesteps):
    """Return ``timesteps`` evenly spaced values from 0.0001 to 0.02 as a
    ``float32`` numpy array."""
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)


# Author's note: the forward process simply adds noise.
# Precomputed schedule constants derived from the linear beta schedule.
timesteps = 200
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1. - betas
# Running product of alphas up to and including each step.
alphas_cumprod = np.cumprod(alphas, axis=0)
# Same sequence shifted right by one, with a leading 1 and the last entry dropped.
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)
sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# The exponent is -0., so every weight evaluates to 1.
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)
def extract(a, t, x_shape):
    """Look up the entries of ``a`` at indices ``t`` and shape them for
    broadcasting.

    Args:
        a: Values to index; wrapped in a ``Tensor``.
        t: Index tensor; its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result is meant to broadcast with.

    Returns:
        The gathered values reshaped to ``(b, 1, ..., 1)`` with
        ``len(x_shape) - 1`` trailing singleton axes.
    """
    b = t.shape[0]
    out = Tensor(a).gather(t, -1)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
# Fetch and unpack the sample image archive into the current directory.
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)
但是这里就只有1张照片啊。
为什么要做这么复杂的操作
通过上面的代码完成了正向过程,也就是给图像加上了噪声。
每一步加噪之后得到的图像都是不一样的。
负责去噪的就是U-Net,它学习了如何区分噪声和真实的、有意义的图像。
训练开始
# Fetch and unpack the training dataset archive into the current directory.
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)
这些是一些小的衣服图
# Per-image preprocessing applied to the 'image' column.
transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    # Affine rescale t -> 2t - 1 (presumably from [0, 1] to [-1, 1]; confirm
    # against the output range of ToTensor).
    lambda t: (t * 2) - 1
]

# NOTE(review): `dataset` must already be defined before this point; its
# construction is not part of this excerpt.
dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(16, drop_remainder=True)
训练这个unet
# Cosine-decayed learning-rate schedule over 10 * 3750 steps.
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# NOTE(review): `image_size` and `channels` are expected to be defined
# earlier; they are not part of this excerpt.
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

# Rename every trainable parameter after the matching entry, by position,
# of parameters_and_names().
name_list = [name for name, _ in unet_model.parameters_and_names()]
for i, item in enumerate(unet_model.trainable_params()):
    item.name = name_list[i]

optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)
这是动态损失缩放器(DynamicLossScaler),用来对loss进行缩放,避免训练时数值过大(溢出)或过小(下溢)。
def forward_fn(data, t, noise=None):
    """Return the loss of ``unet_model`` from ``p_losses`` for one batch.

    Args:
        data: Batch of inputs.
        t: Time steps for the batch.
        noise: Optional noise forwarded to ``p_losses``.
    """
    loss = p_losses(unet_model, data, t, noise)
    return loss


# Differentiate forward_fn with respect to the optimizer's parameters only
# (grad_position=None means no gradient for the inputs).
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)


def train_step(data, t, noise):
    """Run one optimisation step and return the loss.

    Computes the loss and gradients through ``grad_fn`` and applies the
    gradients with ``optimizer``.
    """
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss
这和以往的训练类似
那么正向传播的就是模糊的图
由于像素太低,你得用代码变小,才看得出来这是个衣服
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。