赞
踩
深度学习的典型任务一般认为是分类,检测和分割。
CNN的复兴可以说就是从图像分类开始,imagenet的出现极大推动了这一领域的进步。在这之前,cifar10和mnist也对这个领域的发展起到了极大的推动作用,时至今日,cifar10和mnist也依然不过时,cifar10依旧是一个中等难度的数据集,很多网络在cifar10上的performance已经足以说明该网络的优异了,当年在imagenet上达到sota的resnet同样在cifar10上同样也能达到sota。mnist的作用也一直重要,它是很多人入门深度学习的基础数据集,它的数据集很小,很方便下载存储,几乎可以即下即用,它的训练和收敛速度也非常快,基本全数据集上一个epoch就能达到90%以上的精确率,它的存在还有其他意义 ,如果你实现了一个新的网络结构或者损失函数,在它上面测试一次能帮你很快的定位到bug,一些开发工作,比如量化训练,分布式训练都可以拿它调试bug。
好了,这些数据集都是分类用的,另外两类任务,目前没有特别通用的小型的数据集,对于检测任务,voc,kitti数据集是比较常用的数据集,但是它们都很大,通常达数G甚至数十G,下载缓慢,存储代价高,运行时需要性能极高的机器,不适合用来教学和debug,一些很小的细节也抬高了初学者使用的门槛,就比如voc的数据组织和数据结构,如果没有人带路也不太容易明白。
我在kaggle上搜索了一圈,发现了一个小的分割数据集,名为m2nist,它是mnist的升级版,每张图像包含多余一个数字,但是不超过三个,应该是doublemnist和triplemnist合成而来。
它包含两个文件,combined.npy和segmented.npy,都是numpy格式,可以直接用numpy导入,前者是一个3d张量,每个slice是一个包含数字的图片,其中一幅图片如下所示,
后者是一个4d张量,最后一个维度是11,每个slice代表一个全图的pixelwise的one-hot编码过的mask,最后一个维度的意义是,如果原图包含某个数字,比如6,那么第6个通道,原图中6出现的位置应该是1,其他位置都是0。最后一个通道是背景的mask,意义是如果该位置没有数字出现,应该是1,否则是0。这是因为在分割中背景被当作单独的类别对待,这点尤为重要。上面图像的第0,第4,第7,第10个mask如下所示。
这个数据集时float32数据类型存储的,我把它压缩成了uint8,事实也应该如此。对于segmented图像,我把它压缩成立一个通道,意义为这个位置出现那个数字,比如6,这个位置的数值就是6,背景的数值是10。
同时,我使用联通分量分析,找到了每个数字的bbox,并把bbox数据存储成yolo格式的txt文件,如下所示,
每一行第一个数字表示该标注对应的图像在combined.npy中对应的index,之后每个bbox标注用空白符分隔开,每个bbox标注由五个数字表示,从左向右,分别为xmin,ymin,xmax,ymax和label,即左上和右下的x,y坐标以及该bbox中的目标的label,可视化结果如下所示,
这个数据集压缩后只有2.5M,但是它有5000张图片,足够入门教学使用,它的图片大小为84X64,这点不够友好,因为很多时候我们希望长宽至少是16的倍数,但是也有其他方案,我在我的unet实现中给了一个很好的例子。
在python中使用这个数据集的代码为
- import zipfile
- from six.moves import urllib
- import ssl
- ssl._create_default_https_context = ssl._create_unverified_context
- from tqdm import tqdm
-
-
- def download_from_url(url, dst):
- """
- @param: url to download file
- @param: dst place to put the file
- """
- req = urllib.request.urlopen(url)
- file_size = int(req.info().get('Content-Length', -1))
- if os.path.exists(dst):
- first_byte = os.path.getsize(dst)
- else:
- first_byte = 0
- if first_byte >= file_size:
- return file_size
- header = {"Range": "bytes=%s-%s" % (first_byte, file_size)}
- pbar = tqdm(total=file_size, initial=first_byte, unit='B', unit_scale=True, desc=url.split('/')[-1])
- with (open(dst, 'ab')) as f:
- while True:
- chunk = req.read(1024)
- if not chunk:
- break
- file_size += len(chunk)
- f.write(chunk)
- pbar.update(len(chunk))
- pbar.close()
- return file_size
-
- def download_m2nist_if_not_exist():
- data_rootdir = os.path.expanduser('~/.m2nist')
- m2nist_zip_path = os.path.join(data_rootdir, 'm2nist.zip')
- if os.path.exists(m2nist_zip_path):
- return
- os.makedirs(data_rootdir, exist_ok=True)
- m2nist_zip_url = 'https://raw.githubusercontent.com/akkaze/datasets/master/m2nist.zip'
- while True:
- try:
- print('Trying to download m2nist...')
- download_from_url(m2nist_zip_url, m2nist_zip_path)
- break
- except Exception as exc:
- print('Errors occured : {0}'.format(exc))
- time.sleep(5)
- continue
- zipf = zipfile.ZipFile(m2nist_zip_path)
- zipf.extractall(data_rootdir)
- zipf.close()
它会在用户目录下创建.m2nist目录,并将zip文件下载到这个目录中,希望大家使用愉快。
这个数据集的下载地址为m2nist下载地址,另外我基于此实现的unet地址为unet源码。
unet的收敛曲线如下所示
网络的可训练参数量为25k,收敛很快,可见这个规模的网络对这个数据集依然容易过拟合。
我又继续在yolo3上测试了检测的结果,10个epoch后,能得到如下的结果,
可见,图像质量完全足够用来做检测的,收敛曲线如下(这里只计算了loss,其他metric需要进一步补充),
其中yolo3的代码也使用tf2.x,并开源在了github上,yolo3源码,需要进一步完善。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。