Vision Transformers have achieved great success thanks to their strong modeling capability. However, their remarkable performance comes with a heavy computational cost, which makes them unsuitable for real-time applications. In this paper, we propose a family of high-speed vision Transformers named EfficientViT. We find that the speed of existing Transformer models is commonly bounded by memory-inefficient operations, especially tensor reshaping and element-wise functions in MHSA. We therefore design a new building block with a sandwich layout, i.e., a single memory-bound MHSA between efficient FFN layers, which improves memory efficiency while enhancing channel communication. Moreover, we discover that attention maps share high similarities across heads, leading to computational redundancy. To address this, we present a cascaded group attention module that feeds the attention heads with different splits of the full feature, which not only saves computation cost but also improves attention diversity. Extensive experiments demonstrate that EfficientViT outperforms existing efficient models, striking a good trade-off between speed and accuracy. For instance, our EfficientViT-M5 surpasses MobileNetV3-Large by 1.9% in accuracy, while achieving 40.4% and 45.2% higher throughput on an Nvidia V100 GPU and an Intel Xeon CPU, respectively. Compared to the recent efficient model MobileViT-XXS, EfficientViT-M2 attains 1.8% higher accuracy while running 5.8×/3.7× faster on GPU/CPU, and 7.4× faster when converted to ONNX format.
As shown in Figure 2, the paper first profiles the runtime of two architectures, DeiT and Swin, and finds that the speed of Transformer architectures is typically memory-bound. To address this, it proposes a sandwich layout: a single cascaded-group-attention MHSA module placed between 2N FFN layers. The paper also observes that the existing multi-head split makes the attention maps of different heads highly similar, wasting computation, and proposes a new head-splitting strategy to mitigate this.
By analyzing the DeiT and Swin Transformer architectures, the paper draws the following conclusions:

- Inference speed is largely memory-bound: the tensor reshaping and element-wise functions inside MHSA dominate memory access time, so reducing the proportion of MHSA layers relative to FFN layers improves memory efficiency with little loss in accuracy.
- The attention maps of different heads are highly similar, meaning that the conventional multi-head split introduces computational redundancy.
The overall framework is shown in Figure 6. It consists of three stages, each stacking several sandwich blocks; a sandwich block is built from 2N DWConv layers (local spatial communication) and FFN layers (channel communication) around a single cascaded group attention module. Unlike conventional MHSA, cascaded group attention first splits the input feature across heads and only then generates Q, K, and V for each head. To learn richer feature maps and increase model capacity, the output of each head is added to the input of the next head. Finally, the outputs of all heads are concatenated and passed through a linear layer to produce the final output. In formulas, for the j-th head of the i-th layer:
$$\tilde{X}_{ij} = \mathrm{Attn}\left(X_{ij} W_{ij}^{Q},\; X_{ij} W_{ij}^{K},\; X_{ij} W_{ij}^{V}\right)$$

$$\tilde{X}_{i+1} = \mathrm{Concat}\left[\tilde{X}_{ij}\right]_{j=1:h}\, W_{i}^{P}$$

$$X'_{ij} = X_{ij} + \tilde{X}_{i(j-1)}, \qquad 1 < j \le h$$

where $X_{ij}$ is the $j$-th split of the input feature $X_i$, $h$ is the number of heads, $W_{ij}^{Q}, W_{ij}^{K}, W_{ij}^{V}$ are the per-head projections, and $W_i^{P}$ is the output projection.
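To make the cascade concrete, here is a minimal NumPy sketch of the per-head data flow in the equations above. The sizes and random weights are illustrative only; the real module (implemented later in this notebook) additionally applies a depthwise convolution to Q and adds a learned attention bias.

import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, C, h = 16, 64, 4                      # tokens, channels, heads
d = C // h                               # width of each head's split
X = rng.standard_normal((N, C))
splits = np.split(X, h, axis=1)          # each head sees a different C/h slice of the feature

outs = []
feat = splits[0]
for j in range(h):
    if j > 0:
        feat = splits[j] + outs[-1]      # cascade: previous head's output refines this head's input
    Wq = rng.standard_normal((d, d))
    Wk = rng.standard_normal((d, d))
    Wv = rng.standard_normal((d, d))
    q, k, v = feat @ Wq, feat @ Wk, feat @ Wv
    attn = softmax((q @ k.T) / np.sqrt(d))
    outs.append(attn @ v)

out = np.concatenate(outs, axis=1)       # Concat over heads; a final linear projection W^P follows
print(out.shape)                         # (16, 64)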
!pip install paddlex
%matplotlib inline
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
import paddlex
import itertools
train_tfm = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomRotation(20),
paddlex.transforms.MixupImage(),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
test_tfm = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# 使用Cifar10数据集
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
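A quick sanity check for this implementation: with smoothing=0.0 the loss should coincide with Paddle's built-in cross entropy (a minimal check, not part of the training flow).

# With smoothing=0.0, label smoothing degenerates to standard cross entropy
logits = paddle.randn([8, 10])
target = paddle.randint(0, 10, [8])
ls = LabelSmoothingCrossEntropy(smoothing=0.0)(logits, target)
ce = F.cross_entropy(logits, target)
print(float(ls), float(ce))  # the two values should agree up to floating-point error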
def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    The original name is misleading, as 'Drop Connect' is a different form of dropout in a
    separate paper. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output

class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
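Since surviving samples are rescaled by 1/keep_prob, drop_path preserves the expected activation. A small check in training mode illustrates both the drop rate and the rescaling:

# drop_prob=0.2: about 80% of samples survive, each scaled by 1/0.8,
# so the overall mean stays close to the input mean.
x = paddle.ones([1000, 8])
y = drop_path(x, drop_prob=0.2, training=True)
print(float((y.sum(axis=1) > 0).astype('float32').mean()))  # ~0.8, fraction of surviving samples
print(float(y.mean()))                                      # ~1.0, expectation preserved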
class Conv2D_BN(nn.Sequential):
    def __init__(self, in_channel, out_channel, ks=1, stride=1, padding=0,
                 dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_sublayer('conv', nn.Conv2D(in_channel, out_channel, ks, stride=stride,
                                            padding=padding, groups=groups, dilation=dilation))
        self.add_sublayer('bn', nn.BatchNorm2D(out_channel))
        init = nn.initializer.Constant(bn_weight_init)
        init(self.bn.weight)
        zero = nn.initializer.Constant(0)
        zero(self.bn.bias)

class BN_Linear(nn.Sequential):
    def __init__(self, in_channel, out_channel, bias=True, std=0.02):
        super().__init__()
        self.add_sublayer('bn', nn.BatchNorm1D(in_channel))
        self.add_sublayer('linear', nn.Linear(in_channel, out_channel, bias_attr=bias))
        tn = nn.initializer.TruncatedNormal(std=std)
        tn(self.linear.weight)
        if bias:
            zero = nn.initializer.Constant(0.0)
            zero(self.linear.bias)
class SqueezeExcite(nn.Layer):
    def __init__(self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8,
                 add_maxpool=False, bias=True, act_layer=nn.ReLU, norm_layer=None,
                 gate_layer=nn.Sigmoid):
        super().__init__()
        self.fc1 = nn.Conv2D(channels, int(channels * rd_ratio), kernel_size=1, bias_attr=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2D(int(channels * rd_ratio), channels, kernel_size=1, bias_attr=bias)
        self.gate = gate_layer()

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)

class PatchMerging(nn.Layer):
    def __init__(self, dim, out_dim):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2D_BN(dim, hid_dim, 1, 1, 0)
        self.act = nn.ReLU()
        self.conv2 = Conv2D_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2D_BN(hid_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        return x
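PatchMerging is the downsampling block between stages: a 1x1 expansion, a stride-2 depthwise 3x3, squeeze-excitation, then a 1x1 projection. A shape check with assumed toy sizes:

# Downsampling check: spatial size halves, channels go dim -> out_dim
pm = PatchMerging(64, 128)
x = paddle.randn([2, 64, 14, 14])
print(pm(x).shape)  # [2, 128, 7, 7]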
class Residual(nn.Layer):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        return x + self.dropout(self.m(x))
class FFN(nn.Layer):
    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2D_BN(ed, h)
        self.act = nn.ReLU()
        self.pw2 = Conv2D_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x
class CascadedGroupAttention(nn.Layer):
    def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=14,
                 kernels=[5, 5, 5, 5]):
        super().__init__()
        self.resolution = resolution
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            qkvs.append(Conv2D_BN(dim // (num_heads), self.key_dim * 2 + self.d))
            dws.append(Conv2D_BN(self.key_dim, self.key_dim, kernels[i], stride=1,
                                 padding=kernels[i] // 2, groups=self.key_dim))
        self.qkvs = nn.LayerList(qkvs)
        self.dws = nn.LayerList(dws)
        self.proj = nn.Sequential(nn.ReLU(),
                                  Conv2D_BN(self.d * num_heads, dim, bn_weight_init=0))

        # Precompute the index table for the learned attention biases
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        self.N = N
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter(
            (len(attention_offsets), num_heads),
            default_initializer=nn.initializer.Constant(0.0))
        self.attention_bias_idxs = idxs

    def forward(self, x):  # x (B,C,H,W)
        B, C, H, W = x.shape
        trainingab = self.attention_biases[self.attention_bias_idxs].transpose((1, 0)).reshape(
            (self.num_heads, self.N, self.N))
        feats_in = paddle.chunk(x, len(self.qkvs), axis=1)
        feats_out = []
        feat = feats_in[0]
        for i, qkv in enumerate(self.qkvs):
            if i > 0:  # add the previous output to the input
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.split([self.key_dim, self.key_dim, self.d], axis=1)  # B, C/h, H, W
            q = self.dws[i](q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)  # B, C/h, N
            attn = (q.transpose([0, 2, 1]) @ k) * self.scale
            attn = attn + trainingab[i]
            attn = F.softmax(attn, axis=-1)  # BNN
            feat = (v @ attn.transpose([0, 2, 1])).reshape((B, self.d, H, W))  # BCHW
            feats_out.append(feat)
        x = self.proj(paddle.concat(feats_out, axis=1))
        return x
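A shape check with assumed toy sizes (dim must be divisible by num_heads, and resolution must match the spatial size of the input so the attention-bias table lines up):

attn = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=1, resolution=7)
x = paddle.randn([2, 64, 7, 7])
print(attn(x).shape)  # [2, 64, 7, 7]: proj maps the h*d concatenated channels back to dim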
class LocalWindowAttention(nn.Layer):
    def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=14,
                 window_resolution=7, kernels=[5, 5, 5, 5]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        assert window_resolution > 0, 'window_size must be greater than 0'
        self.window_resolution = window_resolution
        window_resolution = min(window_resolution, resolution)
        self.attn = CascadedGroupAttention(dim, key_dim, num_heads, attn_ratio=attn_ratio,
                                           resolution=window_resolution, kernels=kernels)

    def forward(self, x):
        H = W = self.resolution
        B, C, H_, W_ = x.shape
        # Only check this for classification models
        assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))
        if H <= self.window_resolution and W <= self.window_resolution:
            x = self.attn(x)
        else:
            x = x.transpose([0, 2, 3, 1])
            pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution
            pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution
            padding = pad_b > 0 or pad_r > 0
            if padding:
                x = F.pad(x, (0, pad_r, 0, pad_b))
            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_resolution
            nW = pW // self.window_resolution
            # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
            x = x.reshape((B, nH, self.window_resolution, nW, self.window_resolution, C)).transpose(
                [0, 1, 3, 2, 4, 5]).reshape(
                (B * nH * nW, self.window_resolution, self.window_resolution, C)).transpose([0, 3, 1, 2])
            x = self.attn(x)
            # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
            x = x.transpose((0, 2, 3, 1)).reshape(
                (B, nH, nW, self.window_resolution, self.window_resolution, C)).transpose(
                [0, 1, 3, 2, 4, 5]).reshape((B, pH, pW, C))
            if padding:
                x = x[:, :H, :W]
            x = x.transpose([0, 3, 1, 2])
        return x
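With resolution=14 and window_resolution=7, the input is partitioned into four 7x7 windows, attended within each window, and stitched back together. A check with assumed toy sizes:

lwa = LocalWindowAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=1,
                           resolution=14, window_resolution=7)
x = paddle.randn([2, 64, 14, 14])
print(lwa(x).shape)  # [2, 64, 14, 14]: window partition and reverse preserve the shape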
class EfficientViTBlock(nn.Layer):
    def __init__(self, type, ed, kd, nh=8, ar=4, resolution=14, window_resolution=7,
                 kernels=[5, 5, 5, 5]):
        super().__init__()
        self.dw0 = Residual(Conv2D_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn0 = Residual(FFN(ed, int(ed * 2)))
        if type == 's':
            self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar,
                                                       resolution=resolution,
                                                       window_resolution=window_resolution,
                                                       kernels=kernels))
        self.dw1 = Residual(Conv2D_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn1 = Residual(FFN(ed, int(ed * 2)))

    def forward(self, x):
        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
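The sandwich block is shape-preserving, which is what lets it be stacked several times per stage. A check with assumed toy sizes:

blk = EfficientViTBlock('s', ed=64, kd=16, nh=4, ar=1, resolution=14, window_resolution=7)
x = paddle.randn([2, 64, 14, 14])
print(blk(x).shape)  # [2, 64, 14, 14]: DWConv, FFN and attention all keep (C, H, W)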
class EfficientViT(nn.Layer):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 stages=['s', 's', 's'], embed_dim=[64, 128, 192], key_dim=[16, 16, 16],
                 depth=[1, 2, 3], num_heads=[4, 4, 4], window_size=[7, 7, 7],
                 kernels=[5, 5, 5, 5],
                 down_ops=[['subsample', 2], ['subsample', 2], ['']]):
        super().__init__()
        resolution = img_size
        # Patch embedding
        self.patch_embed = nn.Sequential(
            Conv2D_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), nn.ReLU(),
            Conv2D_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), nn.ReLU(),
            Conv2D_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), nn.ReLU(),
            Conv2D_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1))
        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
        self.blocks1 = []
        self.blocks2 = []
        self.blocks3 = []
        # Build EfficientViT blocks
        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
            for d in range(dpth):
                eval('self.blocks' + str(i + 1)).append(
                    EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
            if do[0] == 'subsample':
                # Build EfficientViT downsample block ('Subsample' stride)
                blk = eval('self.blocks' + str(i + 2))
                resolution_ = (resolution - 1) // do[1] + 1
                blk.append(nn.Sequential(
                    Residual(Conv2D_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2)))))
                blk.append(PatchMerging(*embed_dim[i:i + 2]))
                resolution = resolution_
                blk.append(nn.Sequential(
                    Residual(Conv2D_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2)))))
        self.blocks1 = nn.Sequential(*self.blocks1)
        self.blocks2 = nn.Sequential(*self.blocks2)
        self.blocks3 = nn.Sequential(*self.blocks3)
        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        x = self.head(x)
        return x
num_classes = 10

def EfficientViT_M0():
    model = EfficientViT(embed_dim=[64, 128, 192], depth=[1, 2, 3], num_heads=[4, 4, 4],
                         kernels=[5, 5, 5, 5], num_classes=num_classes)
    return model

def EfficientViT_M1():
    model = EfficientViT(embed_dim=[128, 144, 192], depth=[1, 2, 3], num_heads=[2, 3, 3],
                         kernels=[7, 5, 3, 3], num_classes=num_classes)
    return model

def EfficientViT_M2():
    model = EfficientViT(embed_dim=[128, 192, 224], depth=[1, 2, 3], num_heads=[4, 3, 2],
                         kernels=[7, 5, 3, 3], num_classes=num_classes)
    return model

def EfficientViT_M3():
    model = EfficientViT(embed_dim=[128, 240, 320], depth=[1, 2, 3], num_heads=[4, 3, 4],
                         kernels=[5, 5, 5, 5], num_classes=num_classes)
    return model

def EfficientViT_M4():
    model = EfficientViT(embed_dim=[128, 256, 384], depth=[1, 2, 3], num_heads=[4, 4, 4],
                         kernels=[7, 5, 3, 3], num_classes=num_classes)
    return model

def EfficientViT_M5():
    model = EfficientViT(embed_dim=[192, 288, 384], depth=[1, 3, 4], num_heads=[3, 3, 4],
                         kernels=[7, 5, 3, 3], num_classes=num_classes)
    return model
model = EfficientViT_M0()
paddle.summary(model, (1, 3, 224, 224))
model = EfficientViT_M1()
paddle.summary(model, (1, 3, 224, 224))
model = EfficientViT_M2()
paddle.summary(model, (1, 3, 224, 224))
model = EfficientViT_M3()
paddle.summary(model, (1, 3, 224, 224))
model = EfficientViT_M4()
paddle.summary(model, (1, 3, 224, 224))
model = EfficientViT_M5()
paddle.summary(model, (1, 3, 224, 224))
learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# EfficientViT-M0
model = EfficientViT_M0()
criterion = LabelSmoothingCrossEntropy()
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate,
                                                     T_max=50000 // batch_size * n_epochs,
                                                     verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler,
                                  weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}  # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}     # for recording accuracy
loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0
    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        logits = model(x_data)
        loss = criterion(logits, y_data)
        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        train_loss += loss
        train_num += len(y_data)
    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(
        epoch, total_train_loss.numpy(), train_acc * 100))

    # ---------- Validation ----------
    model.eval()
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)
            loss = criterion(logits, y_data)
        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)
        val_loss += loss
        val_num += len(y_data)
    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(
        epoch, total_val_loss.numpy(), val_acc * 100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))
        print(best_acc)

paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    """Plot the learning curve of the model."""
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9
    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
import time
work_path = 'work/model'
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):
    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:1002
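Note that this end-to-end number includes data loading and, on GPU, may miss asynchronous kernel time. A slightly more careful measurement, sketched below under the assumption of CUDA execution (drop the synchronize calls on CPU), warms up first and synchronizes around the timed region:

# A more careful GPU throughput measurement (sketch): warm up, then
# synchronize around the timed region so pending kernels are counted.
warm = paddle.randn([batch_size, 3, 224, 224])
with paddle.no_grad():
    for _ in range(5):
        model(warm)                      # warm-up iterations (excluded from timing)
paddle.device.cuda.synchronize()
start = time.time()
n_images = 0
with paddle.no_grad():
    for x_data, y_data in val_loader:
        model(x_data)
        n_images += len(y_data)
paddle.device.cuda.synchronize()
print("Throughput: {} imgs/s".format(int(n_images / (time.time() - start))))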
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:55<00:00, 181.62it/s]
The EfficientViT-M0 implemented here reaches 89.4% accuracy on CIFAR-10 with only 2.2M parameters, and achieves a throughput of 1002 imgs/s at an input resolution of 224.
This article is a repost.
Link to the original project