赞
踩
github地址:
https://github.com/Huyf9/mnist_pytorch/
为了规范代码与数据集,我们按如下格式创建项目:
mnist
|—— dataset
    |—— train
    |—— test
ubyte格式数据下载地址:
http://yann.lecun.com/exdb/mnist/
我们需要下载这四个文件:
train-images-idx3-ubyte.gz #training set images
train-labels-idx1-ubyte.gz #training set labels
t10k-images-idx3-ubyte.gz #test set images
t10k-labels-idx1-ubyte.gz #test set labels
这个文件格式真麻烦!这辈子应该也不会再跟它打交道了,直接上代码:
from PIL import Image
import numpy as np
from tqdm import tqdm

# MNIST IDX layout: the image file starts with a 16-byte header
# (magic, count, rows, cols) and the label file with an 8-byte header
# (magic, count). Every byte after the header is one pixel / one label.
_IMAGE_HEADER_SIZE = 16
_LABEL_HEADER_SIZE = 8
_IMAGE_SIDE = 28


def convert(image_path, label_path, n):
    """Convert the first ``n`` samples of an MNIST IDX file pair to PNG images
    plus a text file with one label per line.

    Args:
        image_path: ``[idx3-ubyte image file to read, prefix the PNG file
            names are appended to]``.
        label_path: ``[idx1-ubyte label file to read, text file the labels
            are written to]``.
        n: number of samples to convert.

    Raises:
        ValueError: if either input file holds fewer than ``n`` samples.
    """
    pixels_per_image = _IMAGE_SIDE * _IMAGE_SIDE

    # Read everything in two bulk reads instead of one read() call per byte.
    with open(image_path[0], 'rb') as f_images:
        f_images.seek(_IMAGE_HEADER_SIZE)
        raw_pixels = f_images.read(n * pixels_per_image)
    with open(label_path[0], 'rb') as f_labels:
        # The label header is 8 bytes. Skipping fewer bytes turns header
        # bytes into bogus labels and shifts every label against its image.
        f_labels.seek(_LABEL_HEADER_SIZE)
        raw_labels = f_labels.read(n)

    if len(raw_pixels) != n * pixels_per_image or len(raw_labels) != n:
        raise ValueError(
            'expected %d samples but the input files are shorter' % n
        )

    # uint8 maps straight to Pillow's 8-bit grayscale mode ('L'); the
    # platform default integer dtype may be int64, which Pillow rejects.
    images = np.frombuffer(raw_pixels, dtype=np.uint8).reshape(
        n, _IMAGE_SIDE, _IMAGE_SIDE
    )

    for idx, image in enumerate(tqdm(images)):
        # Zero-pad the index to 7 characters so that sorting the file names
        # reproduces the sample order.
        Image.fromarray(image).save(image_path[1] + str(idx).zfill(7) + '.png')

    with open(label_path[1], 'w') as f_out:
        # Iterating over bytes yields ints, one label per byte.
        f_out.writelines(str(label) + '\n' for label in raw_labels)


# [file to read, where to save]. Labels go to *_label.txt so that they match
# the CSV converter's output and are not overwritten by the annotation step,
# which writes image names to dataset\train.txt and dataset\test.txt.
train_image_path = ['train-images.idx3-ubyte', 'dataset\\train\\']
train_label_path = ['train-labels.idx1-ubyte', 'dataset\\train_label.txt']
test_image_path = ['t10k-images.idx3-ubyte', 'dataset\\test\\']
test_label_path = ['t10k-labels.idx1-ubyte', 'dataset\\test_label.txt']

convert(train_image_path, train_label_path, 60000)
print('Generate the train sets done!')
convert(test_image_path, test_label_path, 10000)
print('Generate the test sets done!')
csv格式数据下载地址:
https://pjreddie.com/projects/mnist-in-csv/
将train set与test set下载。csv文件中的图片格式为:
label, pix-11, pix-12, pix-13, …
第一列为图片的标签。其中pix-ij为第i行第j列的像素值。
我们从csv文件中读取label并保存为txt文件, 读取像素值转化为png图片并保存。
代码如下:
import numpy
import numpy as np
from tqdm import tqdm
from PIL import Image


def convert(mnist, save_path):
    """Convert MNIST CSV files into PNG images and label text files.

    Each CSV row is ``label, pixel, pixel, ...``. The label is written as
    one line of the matching label file; the pixels are reshaped to 28 x 28
    and saved as ``<save_path[i]><index zero-padded to 7 chars>.png`` so
    that sorting the file names reproduces the row order.

    Args:
        mnist: CSV files to read; entry 0 is the train set, entry 1 the
            test set.
        save_path: for each CSV file, the prefix the PNG file names are
            appended to.
    """
    # NOTE: the image index is NOT reset between files, so with the default
    # inputs the test images are numbered from 0060000 onwards.
    idx = 0
    # The context managers guarantee both label files are flushed and closed,
    # even if a row fails to convert.
    with open(r'dataset\\train_label.txt', 'w') as f_tr_label, \
            open(r'dataset\\test_label.txt', 'w') as f_test_label:
        f_label = [f_tr_label, f_test_label]
        for i, csv_path in enumerate(mnist):
            with open(csv_path, 'r') as f_csv:
                rows = f_csv.readlines()
            for row in tqdm(rows):
                row = row.strip()
                if not row:
                    # Skip blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = row.split(',')
                f_label[i].write(fields[0] + '\n')
                # Pixels are read as text: convert them to integers and store
                # them as uint8, which maps straight to Pillow's 8-bit
                # grayscale mode ('L'). The platform default integer dtype
                # may be int64, which Pillow rejects.
                # Assumes pixel values are within 0..255.
                pixels = np.array(list(map(int, fields[1:])), dtype=np.uint8)
                img = Image.fromarray(pixels.reshape((28, 28)))
                img.save(save_path[i] + str(idx).zfill(7) + '.png')
                idx += 1
    print('done!')


mnist = ['mnist_train.csv', 'mnist_test.csv']
save_path = [r'dataset\\train\\', r'dataset\\test\\']
convert(mnist, save_path)
无注释纯净版!
在获取到mnist的图片与标签之后,我们需要对数据集进行一个预处理,来获取数据集的图片路径,方便后续进行数据加载。同时需要将训练集划分为训练部分和验证部分,每一轮训练结束后用验证部分来验证模型性能。
首先在根目录下创建mnist_annotation.py文件,这个文件运行结束后会在dataset目录下生成
train.txt # 保存训练图片路径
test.txt # 保存测试图片路径
val.txt # 保存验证图片路径
tr_label.txt # 保存训练标签(取自 train_label.txt 中划分给训练部分的标签)
test_label.txt # 保存测试标签
val_label.txt # 保存验证标签
代码如下:
import os
from tqdm import tqdm
from typing import List

# Share of the original training images kept for training; the rest becomes
# the validation part (train : val = 9 : 1). The test set is already
# separate, so it is not split.
trainval_percent = 0.9


def _write_lines(path, lines):
    """Write ``lines`` (each already newline-terminated) to ``path``."""
    with open(path, 'w') as f:
        f.writelines(lines)


def split_sets(train_path, test_path):
    """Generate the image-name and label lists for train / val / test.

    Writes, inside ``dataset``: ``test.txt``, ``train.txt`` and ``val.txt``
    (image file names, one per line) and ``tr_label.txt`` /
    ``val_label.txt`` (labels, one per line). The first
    ``trainval_percent`` of the training images go to the training part and
    the remainder to the validation part.

    Args:
        train_path: ``[directory with the training images, text file with
            one training label per line]``.
        test_path: directory with the test images.
    """
    # os.listdir() returns entries in arbitrary order. Sort so that the i-th
    # image name lines up with the i-th label line; this relies on the image
    # file names sorting in the same order as the label lines (the
    # converters zero-pad the index for that purpose).
    test_pic_paths = sorted(os.listdir(test_path))
    _write_lines(r'dataset\test.txt',
                 (name + '\n' for name in tqdm(test_pic_paths)))
    print('Generate test.txt done!')

    train_pic_path = sorted(os.listdir(train_path[0]))
    train_num = int(len(train_pic_path) * trainval_percent)

    _write_lines(r'dataset\train.txt',
                 (name + '\n' for name in tqdm(train_pic_path[:train_num])))
    print('Generate train.txt done!')

    # Image names of the validation part, matching val_label.txt line by line.
    _write_lines(r'dataset\val.txt',
                 (name + '\n' for name in train_pic_path[train_num:]))
    print('Generate val.txt done!')

    with open(train_path[1], 'r') as f:
        train_labels = f.readlines()

    # The label lines keep their own trailing newline.
    _write_lines(r'dataset\tr_label.txt', tqdm(train_labels[:train_num]))
    print('Generate tr_label.txt done!')

    _write_lines(r'dataset\val_label.txt', train_labels[train_num:])
    print('Generate val_label.txt done!')


train_path = ['dataset\\train', 'dataset\\train_label.txt']
test_path = 'dataset\\test'
split_sets(train_path, test_path)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。