赞
踩
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
我们提供包含前1000个训练图像和5个随机测试图像的数据集的小规模样本
# Register the small sample of the CIFAR-10 Kaggle data (the first 1000
# training images and 5 random test images) as (download URL, hash).
# The second element is presumably a SHA-1 checksum used by d2l to verify
# the download -- TODO confirm against d2l.download.
d2l.DATA_HUB['cifar10_tiny'] = (
    d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
    '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd',
)

# demo=True works on the tiny sample; set it to False to use the full
# dataset, expected to be already unpacked under ../data/cifar-10/.
demo = True
data_dir = (d2l.download_extract('cifar10_tiny') if demo
            else '../data/cifar-10/')
Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...
整理数据集
def read_csv_labels(fname):
    """Read a two-column ``name,label`` CSV file into a ``{name: label}`` dict.

    The first line is treated as a header and skipped. Blank lines (for
    example an empty line at the end of the file) are ignored.

    Args:
        fname: path of the CSV file.

    Returns:
        dict mapping the first column (e.g. an image id / file stem) to the
        second column (its label); both keys and values are strings.

    Raises:
        ValueError: if a non-blank data line does not contain exactly two
            comma-separated fields.
    """
    with open(fname, 'r', encoding='utf-8') as f:
        # Skip the header row.
        lines = f.readlines()[1:]
    # Ignore blank lines; previously they produced [''] and crashed the
    # ``name, label`` unpacking below.
    tokens = [line.rstrip().split(',') for line in lines if line.strip()]
    return {name: label for name, label in tokens}
# Load the training-image id -> class-name mapping from trainLabels.csv.
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
# Bare expression: in a notebook this displays the mapping (output below).
labels
{'1': 'frog', '2': 'truck', '3': 'truck', '4': 'deer', '5': 'automobile', '6': 'automobile', '7': 'bird', '8': 'horse', '9': 'ship', '10': 'cat', '11': 'deer', '12': 'horse', '13': 'horse', '14': 'bird', '15': 'truck', '16': 'truck', '17': 'truck', '18': 'cat', '19': 'bird', '20': 'frog', '21': 'deer', '22': 'cat', '23': 'frog', '24': 'frog', '25': 'bird', '26': 'frog', '27': 'cat', '28': 'dog', '29': 'deer', '30': 'airplane', '31': 'airplane', '32': 'truck', '33': 'automobile', '34': 'cat', '35': 'deer', '36': 'airplane', '37': 'cat', '38': 'horse', '39': 'cat', '40': 'cat', '41': 'dog', '42': 'bird', '43': 'bird', '44': 'horse', '45': 'automobile', '46': 'automobile', '47': 'automobile', '48': 'bird', '49': 'bird', '50': 'airplane', '51': 'truck', '52': 'dog', '53': 'horse', '54': 'truck', '55': 'bird', '56': 'bird', '57': 'dog', '58': 'bird', '59': 'deer', '60': 'cat', '61': 'automobile', '62': 'automobile', '63': 'ship', '64': 'bird', '65': 'automobile', '66': 'automobile', '67': 'deer', '68': 'truck', '69': 'horse', '70': 'ship', '71': 'dog', '72': 'truck', '73': 'frog', '74': 'horse', '75': 'cat', '76': 'automobile', '77': 'truck', '78': 'airplane', '79': 'cat', '80': 'automobile', '81': 'cat', '82': 'dog', '83': 'deer', '84': 'dog', '85': 'horse', '86': 'horse', '87': 'deer', '88': 'horse', '89': 'truck', '90': 'deer', '91': 'bird', '92': 'cat', '93': 'ship', '94': 'airplane', '95': 'automobile', '96': 'frog', '97': 'automobile', '98': 'automobile', '99': 'deer', '100': 'automobile', '101': 'ship', '102': 'cat', '103': 'truck', '104': 'frog', '105': 'frog', '106': 'automobile', '107': 'ship', '108': 'dog', '109': 'bird', '110': 'truck', '111': 'truck', '112': 'ship', '113': 'automobile', '114': 'horse', '115': 'horse', '116': 'airplane', '117': 'airplane', '118': 'frog', '119': 'truck', '120': 'automobile', '121': 'bird', '122': 'bird', '123': 'truck', '124': 'bird', '125': 'frog', '126': 'frog', '127': 'automobile', '128': 'truck', '129': 'dog', '130': 
'airplane', '131': 'deer', '132': 'horse', '133': 'frog', '134': 'horse', '135': 'automobile', '136': 'ship', '137': 'automobile', '138': 'automobile', '139': 'bird', '140': 'ship', '141': 'automobile', '142': 'cat', '143': 'cat', '144': 'frog', '145': 'bird', '146': 'deer', '147': 'truck', '148': 'truck', '149': 'dog', '150': 'deer', '151': 'cat', '152': 'frog', '153': 'horse', '154': 'deer', '155': 'frog', '156': 'ship', '157': 'dog', '158': 'dog', '159': 'deer', '160': 'cat', '161': 'automobile', '162': 'ship', '163': 'deer', '164': 'horse', '165': 'frog', '166': 'airplane', '167': 'truck', '168': 'dog', '169': 'automobile', '170': 'cat', '171': 'ship', '172': 'bird', '173': 'horse', '174': 'dog', '175': 'cat', '176': 'deer', '177': 'automobile', '178': 'dog', '179': 'horse', '180': 'airplane', '181': 'deer', '182': 'horse', '183': 'dog', '184': 'dog', '185': 'automobile', '186': 'airplane', '187': 'truck', '188': 'frog', '189': 'truck', '190': 'airplane', '191': 'ship', '192': 'horse', '193': 'ship', '194': 'ship', '195': 'bird', '196': 'dog', '197': 'bird', '198': 'cat', '199': 'dog', '200': 'airplane', '201': 'frog', '202': 'automobile', '203': 'truck', '204': 'cat', '205': 'frog', '206': 'truck', '207': 'automobile', '208': 'cat', '209': 'truck', '210': 'frog', '211': 'frog', '212': 'horse', '213': 'automobile', '214': 'airplane', '215': 'truck', '216': 'dog', '217': 'ship', '218': 'dog', '219': 'bird', '220': 'truck', '221': 'airplane', '222': 'ship', '223': 'ship', '224': 'airplane', '225': 'frog', '226': 'truck', '227': 'automobile', '228': 'automobile', '229': 'frog', '230': 'cat', '231': 'horse', '232': 'frog', '233': 'frog', '234': 'airplane', '235': 'frog', '236': 'frog', '237': 'automobile', '238': 'horse', '239': 'automobile', '240': 'dog', '241': 'ship', '242': 'cat', '243': 'frog', '244': 'frog', '245': 'ship', '246': 'frog', '247': 'ship', '248': 'deer', '249': 'frog', '250': 'frog', '251': 'automobile', '252': 'cat', '253': 'ship', '254': 'cat', 
'255': 'deer', '256': 'automobile', '257': 'horse', '258': 'automobile', '259': 'cat', '260': 'ship', '261': 'dog', '262': 'automobile', '263': 'automobile', '264': 'deer', '265': 'airplane', '266': 'truck', '267': 'cat', '268': 'horse', '269': 'deer', '270': 'truck', '271': 'truck', '272': 'bird', '273': 'deer', '274': 'truck', '275': 'truck', '276': 'automobile', '277': 'airplane', '278': 'dog', '279': 'truck', '280': 'airplane', '281': 'ship', '282': 'bird', '283': 'automobile', '284': 'bird', '285': 'airplane', '286': 'dog', '287': 'frog', '288': 'cat', '289': 'bird', '290': 'horse', '291': 'ship', '292': 'ship', '293': 'frog', '294': 'airplane', '295': 'horse', '296': 'truck', '297': 'deer', '298': 'dog', '299': 'frog', '300': 'deer', '301': 'bird', '302': 'automobile', '303': 'automobile', '304': 'bird', '305': 'automobile', '306': 'dog', '307': 'truck', '308': 'truck', '309': 'airplane', '310': 'ship', '311': 'deer', '312': 'automobile', '313': 'automobile', '314': 'frog', '315': 'cat', '316': 'cat', '317': 'truck', '318': 'airplane', '319': 'horse', '320': 'truck', '321': 'horse', '322': 'horse', '323': 'truck', '324': 'automobile', '325': 'dog', '326': 'automobile', '327': 'frog', '328': 'frog', '329': 'ship', '330': 'horse', '331': 'automobile', '332': 'cat', '333': 'airplane', '334': 'cat', '335': 'cat', '336': 'bird', '337': 'deer', '338': 'dog', '339': 'horse', '340': 'dog', '341': 'truck', '342': 'airplane', '343': 'cat', '344': 'deer', '345': 'airplane', '346': 'deer', '347': 'deer', '348': 'frog', '349': 'airplane', '350': 'airplane', '351': 'frog', '352': 'frog', '353': 'airplane', '354': 'ship', '355': 'automobile', '356': 'frog', '357': 'bird', '358': 'truck', '359': 'bird', '360': 'dog', '361': 'truck', '362': 'frog', '363': 'horse', '364': 'deer', '365': 'automobile', '366': 'ship', '367': 'horse', '368': 'cat', '369': 'frog', '370': 'truck', '371': 'cat', '372': 'airplane', '373': 'deer', '374': 'airplane', '375': 'dog', '376': 'automobile', 
'377': 'airplane', '378': 'cat', '379': 'deer', '380': 'ship', '381': 'dog', '382': 'deer', '383': 'horse', '384': 'bird', '385': 'cat', '386': 'truck', '387': 'horse', '388': 'frog', '389': 'horse', '390': 'automobile', '391': 'deer', '392': 'horse', '393': 'airplane', '394': 'automobile', '395': 'horse', '396': 'cat', '397': 'automobile', '398': 'ship', '399': 'deer', '400': 'deer', '401': 'bird', '402': 'airplane', '403': 'bird', '404': 'bird', '405': 'airplane', '406': 'airplane', '407': 'truck', '408': 'airplane', '409': 'truck', '410': 'frog', '411': 'ship', '412': 'bird', '413': 'horse', '414': 'horse', '415': 'deer', '416': 'airplane', '417': 'cat', '418': 'airplane', '419': 'ship', '420': 'truck', '421': 'deer', '422': 'bird', '423': 'horse', '424': 'bird', '425': 'dog', '426': 'bird', '427': 'dog', '428': 'automobile', '429': 'truck', '430': 'deer', '431': 'ship', '432': 'dog', '433': 'automobile', '434': 'horse', '435': 'deer', '436': 'deer', '437': 'airplane', '438': 'frog', '439': 'truck', '440': 'airplane', '441': 'horse', '442': 'ship', '443': 'ship', '444': 'truck', '445': 'truck', '446': 'cat', '447': 'cat', '448': 'deer', '449': 'airplane', '450': 'deer', '451': 'dog', '452': 'frog', '453': 'frog', '454': 'airplane', '455': 'automobile', '456': 'airplane', '457': 'ship', '458': 'airplane', '459': 'deer', '460': 'ship', '461': 'ship', '462': 'automobile', '463': 'dog', '464': 'bird', '465': 'frog', '466': 'ship', '467': 'automobile', '468': 'airplane', '469': 'airplane', '470': 'horse', '471': 'horse', '472': 'dog', '473': 'truck', '474': 'frog', '475': 'bird', '476': 'ship', '477': 'cat', '478': 'deer', '479': 'horse', '480': 'cat', '481': 'truck', '482': 'airplane', '483': 'automobile', '484': 'bird', '485': 'deer', '486': 'ship', '487': 'automobile', '488': 'ship', '489': 'frog', '490': 'deer', '491': 'deer', '492': 'dog', '493': 'horse', '494': 'automobile', '495': 'cat', '496': 'truck', '497': 'ship', '498': 'airplane', '499': 'automobile', 
'500': 'horse', '501': 'dog', '502': 'ship', '503': 'bird', '504': 'ship', '505': 'airplane', '506': 'deer', '507': 'automobile', '508': 'ship', '509': 'truck', '510': 'ship', '511': 'bird', '512': 'truck', '513': 'truck', '514': 'bird', '515': 'horse', '516': 'dog', '517': 'horse', '518': 'cat', '519': 'ship', '520': 'ship', '521': 'deer', '522': 'deer', '523': 'bird', '524': 'horse', '525': 'automobile', '526': 'frog', '527': 'deer', '528': 'airplane', '529': 'deer', '530': 'frog', '531': 'truck', '532': 'horse', '533': 'frog', '534': 'bird', '535': 'dog', '536': 'dog', '537': 'automobile', '538': 'horse', '539': 'bird', '540': 'bird', '541': 'bird', '542': 'truck', '543': 'dog', '544': 'deer', '545': 'bird', '546': 'horse', '547': 'ship', '548': 'automobile', '549': 'cat', '550': 'deer', '551': 'cat', '552': 'horse', '553': 'frog', '554': 'truck', '555': 'ship', '556': 'airplane', '557': 'frog', '558': 'airplane', '559': 'bird', '560': 'bird', '561': 'bird', '562': 'automobile', '563': 'ship', '564': 'deer', '565': 'airplane', '566': 'automobile', '567': 'ship', '568': 'ship', '569': 'automobile', '570': 'dog', '571': 'horse', '572': 'frog', '573': 'deer', '574': 'dog', '575': 'ship', '576': 'horse', '577': 'automobile', '578': 'truck', '579': 'automobile', '580': 'truck', '581': 'ship', '582': 'deer', '583': 'horse', '584': 'cat', '585': 'ship', '586': 'ship', '587': 'bird', '588': 'frog', '589': 'frog', '590': 'horse', '591': 'automobile', '592': 'frog', '593': 'ship', '594': 'automobile', '595': 'truck', '596': 'horse', '597': 'ship', '598': 'cat', '599': 'airplane', '600': 'automobile', '601': 'airplane', '602': 'ship', '603': 'ship', '604': 'cat', '605': 'airplane', '606': 'airplane', '607': 'automobile', '608': 'dog', '609': 'airplane', '610': 'ship', '611': 'ship', '612': 'horse', '613': 'truck', '614': 'truck', '615': 'airplane', '616': 'truck', '617': 'deer', '618': 'automobile', '619': 'cat', '620': 'frog', '621': 'frog', '622': 'deer', '623': 'deer', 
'624': 'horse', '625': 'dog', '626': 'frog', '627': 'airplane', '628': 'ship', '629': 'airplane', '630': 'cat', '631': 'bird', '632': 'ship', '633': 'deer', '634': 'frog', '635': 'truck', '636': 'truck', '637': 'horse', '638': 'airplane', '639': 'cat', '640': 'cat', '641': 'frog', '642': 'horse', '643': 'deer', '644': 'truck', '645': 'automobile', '646': 'frog', '647': 'bird', '648': 'horse', '649': 'bird', '650': 'bird', '651': 'airplane', '652': 'frog', '653': 'horse', '654': 'dog', '655': 'horse', '656': 'frog', '657': 'ship', '658': 'truck', '659': 'airplane', '660': 'truck', '661': 'deer', '662': 'deer', '663': 'horse', '664': 'airplane', '665': 'truck', '666': 'deer', '667': 'truck', '668': 'frog', '669': 'truck', '670': 'deer', '671': 'dog', '672': 'horse', '673': 'truck', '674': 'bird', '675': 'deer', '676': 'dog', '677': 'automobile', '678': 'deer', '679': 'cat', '680': 'truck', '681': 'frog', '682': 'dog', '683': 'frog', '684': 'truck', '685': 'cat', '686': 'cat', '687': 'dog', '688': 'airplane', '689': 'horse', '690': 'bird', '691': 'automobile', '692': 'cat', '693': 'frog', '694': 'deer', '695': 'airplane', '696': 'airplane', '697': 'bird', '698': 'dog', '699': 'airplane', '700': 'automobile', '701': 'airplane', '702': 'bird', '703': 'cat', '704': 'truck', '705': 'ship', '706': 'deer', '707': 'truck', '708': 'ship', '709': 'airplane', '710': 'bird', '711': 'frog', '712': 'deer', '713': 'deer', '714': 'airplane', '715': 'automobile', '716': 'ship', '717': 'ship', '718': 'cat', '719': 'frog', '720': 'truck', '721': 'frog', '722': 'frog', '723': 'horse', '724': 'ship', '725': 'bird', '726': 'deer', '727': 'dog', '728': 'horse', '729': 'frog', '730': 'dog', '731': 'cat', '732': 'airplane', '733': 'dog', '734': 'airplane', '735': 'dog', '736': 'airplane', '737': 'ship', '738': 'bird', '739': 'frog', '740': 'horse', '741': 'cat', '742': 'ship', '743': 'bird', '744': 'automobile', '745': 'horse', '746': 'frog', '747': 'horse', '748': 'automobile', '749': 
'airplane', '750': 'truck', '751': 'dog', '752': 'dog', '753': 'airplane', '754': 'automobile', '755': 'horse', '756': 'frog', '757': 'truck', '758': 'airplane', '759': 'deer', '760': 'horse', '761': 'horse', '762': 'automobile', '763': 'dog', '764': 'truck', '765': 'deer', '766': 'airplane', '767': 'ship', '768': 'dog', '769': 'truck', '770': 'truck', '771': 'frog', '772': 'horse', '773': 'automobile', '774': 'ship', '775': 'cat', '776': 'bird', '777': 'cat', '778': 'ship', '779': 'bird', '780': 'bird', '781': 'deer', '782': 'frog', '783': 'airplane', '784': 'airplane', '785': 'dog', '786': 'cat', '787': 'ship', '788': 'bird', '789': 'cat', '790': 'horse', '791': 'bird', '792': 'truck', '793': 'cat', '794': 'ship', '795': 'horse', '796': 'ship', '797': 'bird', '798': 'horse', '799': 'truck', '800': 'airplane', '801': 'bird', '802': 'cat', '803': 'bird', '804': 'bird', '805': 'bird', '806': 'cat', '807': 'cat', '808': 'frog', '809': 'bird', '810': 'cat', '811': 'bird', '812': 'ship', '813': 'airplane', '814': 'dog', '815': 'dog', '816': 'automobile', '817': 'deer', '818': 'dog', '819': 'frog', '820': 'frog', '821': 'bird', '822': 'horse', '823': 'airplane', '824': 'automobile', '825': 'horse', '826': 'horse', '827': 'ship', '828': 'bird', '829': 'truck', '830': 'bird', '831': 'bird', '832': 'deer', '833': 'bird', '834': 'automobile', '835': 'automobile', '836': 'automobile', '837': 'frog', '838': 'frog', '839': 'frog', '840': 'dog', '841': 'automobile', '842': 'automobile', '843': 'horse', '844': 'airplane', '845': 'deer', '846': 'cat', '847': 'cat', '848': 'horse', '849': 'automobile', '850': 'bird', '851': 'cat', '852': 'dog', '853': 'dog', '854': 'dog', '855': 'frog', '856': 'automobile', '857': 'deer', '858': 'cat', '859': 'horse', '860': 'ship', '861': 'ship', '862': 'cat', '863': 'frog', '864': 'frog', '865': 'bird', '866': 'cat', '867': 'airplane', '868': 'truck', '869': 'deer', '870': 'cat', '871': 'ship', '872': 'airplane', '873': 'airplane', '874': 
'automobile', '875': 'automobile', '876': 'dog', '877': 'deer', '878': 'truck', '879': 'cat', '880': 'automobile', '881': 'ship', '882': 'truck', '883': 'cat', '884': 'truck', '885': 'truck', '886': 'bird', '887': 'truck', '888': 'deer', '889': 'ship', '890': 'bird', '891': 'truck', '892': 'ship', '893': 'ship', '894': 'automobile', '895': 'dog', '896': 'cat', '897': 'frog', '898': 'ship', '899': 'horse', '900': 'frog', '901': 'truck', '902': 'ship', '903': 'airplane', '904': 'frog', '905': 'deer', '906': 'airplane', '907': 'airplane', '908': 'bird', '909': 'dog', '910': 'ship', '911': 'bird', '912': 'airplane', '913': 'bird', '914': 'horse', '915': 'frog', '916': 'truck', '917': 'horse', '918': 'automobile', '919': 'dog', '920': 'dog', '921': 'frog', '922': 'frog', '923': 'cat', '924': 'frog', '925': 'bird', '926': 'deer', '927': 'horse', '928': 'airplane', '929': 'dog', '930': 'frog', '931': 'deer', '932': 'frog', '933': 'dog', '934': 'bird', '935': 'deer', '936': 'frog', '937': 'automobile', '938': 'frog', '939': 'airplane', '940': 'deer', '941': 'airplane', '942': 'cat', '943': 'automobile', '944': 'ship', '945': 'dog', '946': 'deer', '947': 'deer', '948': 'automobile', '949': 'horse', '950': 'cat', '951': 'truck', '952': 'deer', '953': 'horse', '954': 'truck', '955': 'horse', '956': 'cat', '957': 'horse', '958': 'bird', '959': 'ship', '960': 'deer', '961': 'frog', '962': 'frog', '963': 'automobile', '964': 'bird', '965': 'truck', '966': 'airplane', '967': 'deer', '968': 'ship', '969': 'horse', '970': 'cat', '971': 'truck', '972': 'ship', '973': 'horse', '974': 'horse', '975': 'airplane', '976': 'bird', '977': 'deer', '978': 'automobile', '979': 'automobile', '980': 'deer', '981': 'automobile', '982': 'dog', '983': 'deer', '984': 'airplane', '985': 'dog', '986': 'frog', '987': 'bird', '988': 'ship', '989': 'dog', '990': 'airplane', '991': 'bird', '992': 'automobile', '993': 'cat', '994': 'dog', '995': 'horse', '996': 'cat', '997': 'dog', '998': 'automobile', 
'999': 'cat', '1000': 'dog'}
将验证集从原始的训练集中拆分出来
# A simple and common way to load a labelled image dataset in PyTorch is
# torchvision.datasets.ImageFolder: create one sub-folder per label and put
# that label's images inside it. The helpers below build that layout.
def copyfile(filename, target_dir):
    """Copy ``filename`` into ``target_dir``, creating the directory first."""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)


# Layout created under <data_dir>/train_valid_test/:
#   train/<label>/        images used for training
#   valid/<label>/        images held out for validation
#   train_valid/<label>/  every labelled image (train + valid together)
def reorg_train_valid(data_dir, labels, valid_ratio):
    """Split ``<data_dir>/train`` into per-label train / valid folders.

    Every label contributes the same number of validation images,
    ``max(1, floor(n * valid_ratio))``, where ``n`` is the image count of
    the rarest label. The validation set is therefore class-balanced.

    Args:
        data_dir: dataset root containing a ``train`` folder of images.
        labels: dict mapping an image's file stem (name before the first
            '.') to its label.
        valid_ratio: fraction of the rarest label's count held out per label.

    Returns:
        The number of validation images taken per label.
    """
    # Image count of the least frequent label.
    n = collections.Counter(labels.values()).most_common()[-1][1]
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    out_root = os.path.join(data_dir, 'train_valid_test')
    # Sorted so that the train/valid split is reproducible across file systems.
    for train_file in sorted(os.listdir(os.path.join(data_dir, 'train'))):
        label = labels[train_file.split('.')[0]]
        fname = os.path.join(data_dir, 'train', train_file)
        # Must match the 'train_valid_test/train_valid' folder read by the
        # ImageFolder datasets (was wrongly 'train_valid_set').
        copyfile(fname, os.path.join(out_root, 'train_valid', label))
        if label_count.get(label, 0) < n_valid_per_label:
            copyfile(fname, os.path.join(out_root, 'valid', label))
            # Count the image; without this every image went to 'valid'.
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(out_root, 'train', label))
    return n_valid_per_label
在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """Copy every file of ``<data_dir>/test`` into a single 'unknown' class.

    The test images have no labels. Placing them all in
    ``train_valid_test/test/unknown`` lets ImageFolder read them like the
    labelled splits.
    """
    source_dir = os.path.join(data_dir, 'test')
    target_dir = os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')
    for name in os.listdir(source_dir):
        copyfile(os.path.join(source_dir, name), target_dir)
调用前面定义的函数
def reorg_cifar10_data(data_dir, valid_ratio):
    """Reorganize the raw CIFAR-10 Kaggle files into ImageFolder layout.

    Reads ``trainLabels.csv``, splits the training images into
    train / valid / train_valid folders, then moves the test images into
    their own folder.
    """
    label_map = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, label_map, valid_ratio)
    reorg_test(data_dir)
# Smaller batches for the tiny demo sample, larger ones for the full data.
if demo:
    batch_size = 32
else:
    batch_size = 128
# Per label, floor(valid_ratio * size of the rarest label) images (at least
# one) are held out for validation by reorg_train_valid.
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
图像增广
# Training-time augmentation.
transform_train = torchvision.transforms.Compose([
    # Upscale the image (40 pixels for the square 32x32 CIFAR images) ...
    torchvision.transforms.Resize(size=40),
    # ... then take a random square crop covering 64%-100% of the area and
    # resize it back to 32x32.
    torchvision.transforms.RandomResizedCrop(size=32, scale=(0.64, 1.0),
                                             ratio=(1.0, 1.0)),
    # Random horizontal flip.
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # Normalize each of the three RGB channels with a per-channel mean / std
    # (presumably CIFAR-10 channel statistics -- TODO confirm source).
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
# Evaluation-time pipeline: no random augmentation, same normalization.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
读取由原始图像组成的数据集
# Augmented datasets used for training: 'train' alone, and 'train_valid'
# (train + valid) for the final model.
train_ds, train_valid_ds = (
    torchvision.datasets.ImageFolder(
        root=os.path.join(data_dir, 'train_valid_test', split),
        transform=transform_train,
    )
    for split in ('train', 'train_valid'))
# Un-augmented datasets used for evaluation and prediction.
valid_ds, test_ds = (
    torchvision.datasets.ImageFolder(
        root=os.path.join(data_dir, 'train_valid_test', split),
        transform=transform_test,
    )
    for split in ('valid', 'test'))
创建数据迭代器：训练集迭代器打乱顺序并使用上面定义的训练图像增广；验证集和测试集迭代器不打乱，也不做随机增广
# Training loaders: shuffle each epoch. drop_last=True discards a final
# batch smaller than batch_size.
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
                                drop_last=True)
    for dataset in (train_ds, train_valid_ds)]
# Evaluation must see every validation image. With drop_last=True the
# trailing samples were silently skipped, and a validation set smaller than
# one batch produced an empty loader (division by zero when computing
# accuracy).
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
                                         drop_last=False)
# Prediction keeps the file order and every test image.
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
                                        drop_last=False)
模型
def get_net():
    """Return a freshly initialised d2l ResNet-18 for CIFAR-10.

    The model has 10 output classes and 3 input channels (RGB).
    """
    return d2l.resnet18(10, 3)


# reduction="none": the loss returns one value per example; summing or
# averaging is left to the caller.
loss = nn.CrossEntropyLoss(reduction="none")
训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
          lr_decay):
    """Train ``net`` with SGD + momentum and a step learning-rate schedule.

    Args:
        net: model to train. It is wrapped in ``nn.DataParallel`` over
            ``devices`` inside this function.
        train_iter: iterable of ``(features, labels)`` training batches.
        valid_iter: validation batches, or ``None`` to skip validation.
        num_epochs: number of passes over ``train_iter``.
        lr: initial learning rate.
        wd: weight decay passed to SGD.
        devices: list of devices. The wrapped model lives on ``devices[0]``.
        lr_period: the learning rate is multiplied by ``lr_decay`` every
            ``lr_period`` epochs (the scheduler is stepped once per epoch).
        lr_decay: multiplicative learning-rate decay factor.

    Uses the module-level ``loss``. Plots progress with ``d2l.Animator`` and
    prints the final train loss / accuracy (and validation accuracy).
    """
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                              weight_decay=wd)
    # StepLR multiplies the lr by lr_decay every lr_period scheduler.step()
    # calls. scheduler.step() is called once per epoch below.
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    # Plot about five points per epoch. max(1, ...) avoids a modulo-by-zero
    # when an epoch has fewer than five batches.
    plot_every = max(1, num_batches // 5)
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        # Running sums: loss, accuracy count, number of examples. This
        # assumes d2l.train_batch_ch13 returns per-batch sums -- TODO confirm.
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss,
                                          trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[2],
                              None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f}, '
                f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
训练和验证模型
# Hyper-parameters. The learning rate is multiplied by lr_decay every
# lr_period epochs (see train).
devices = d2l.try_all_gpus()
num_epochs, lr, wd = 20, 2e-4, 5e-4
lr_period, lr_decay = 4, 0.9
net = get_net()
# Train on 'train' and monitor accuracy on the held-out 'valid' split.
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices,
      lr_period, lr_decay)
对测试集进行分类并提交结果
# Retrain a fresh model on train + valid, then predict the test set.
net, preds = get_net(), []
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,
      lr_decay)
# train() leaves the model in training mode (no validation pass runs when
# valid_iter is None). Switch to eval mode so BatchNorm uses its running
# statistics, and disable autograd for inference.
net.eval()
with torch.no_grad():
    for X, _ in test_iter:
        y_hat = net(X.to(devices[0]))
        preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
# Test images are assumed to be named 1..N. Sorting the ids as strings
# ('1' < '10' < '100' < '2') is meant to reproduce the sorted file-name order
# in which ImageFolder is assumed to enumerate them -- TODO confirm.
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
# Map class indices back to class (folder) names.
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
我们提供完整数据集的小规模样本
# Register a small sample of the full dog-breed dataset as (download URL,
# hash). The hash is presumably a SHA-1 checksum checked by d2l on
# download -- TODO confirm.
d2l.DATA_HUB['dog_tiny'] = (
    d2l.DATA_URL + 'kaggle_dog_tiny.zip',
    '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d',
)

# demo=True uses the tiny sample; set it to False to use the full dataset,
# expected under ../data/dog-breed-identification.
demo = True
data_dir = (d2l.download_extract('dog_tiny') if demo
            else os.path.join('..', 'data', 'dog-breed-identification'))
Downloading ../data/kaggle_dog_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_dog_tiny.zip...
整理数据
def reorg_dog_data(data_dir, valid_ratio):
    """Reorganize the dog-breed Kaggle files into ImageFolder layout.

    Reads ``labels.csv`` and delegates the train/valid split and the
    test-folder setup to the d2l helpers.
    """
    csv_path = os.path.join(data_dir, 'labels.csv')
    d2l.reorg_train_valid(data_dir, d2l.read_csv_labels(csv_path), valid_ratio)
    d2l.reorg_test(data_dir)
# Smaller batches for the tiny demo sample, larger ones for the full data.
if demo:
    batch_size = 32
else:
    batch_size = 128
# Fraction used by d2l.reorg_train_valid to size the validation split.
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)
图片增广
# Training-time augmentation. Restored from a collapsed line in which the
# inline comment swallowed the rest of the list (SyntaxError).
transform_train = torchvision.transforms.Compose([
    # Random crop covering 8%-100% of the image area with aspect ratio in
    # [3/4, 4/3], resized to 224x224.
    torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),
                                             ratio=(3.0 / 4.0, 4.0 / 3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    # Randomly jitter brightness, contrast and saturation.
    torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                       saturation=0.4),
    torchvision.transforms.ToTensor(),
    # Per-channel RGB normalization. These are the usual ImageNet statistics,
    # presumably chosen to match the pretrained backbone.
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])])

# Evaluation-time pipeline: deterministic resize + center crop.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    # Take the central 224x224 region.
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])])
# Restored from a collapsed line holding several statements with no
# separator (SyntaxError).
# Augmented datasets for training; un-augmented ones for evaluation.
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train) for folder in ['train', 'train_valid']]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test) for folder in ['valid', 'test']]

# Training loaders shuffle and drop a final incomplete batch.
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
                                drop_last=True)
    for dataset in (train_ds, train_valid_ds)]
# Keep every validation image: drop_last=True skipped the trailing samples
# and gave an empty loader (division by zero in evaluate_loss) when the
# validation set was smaller than one batch.
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
                                         drop_last=False)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
                                        drop_last=False)
微调预训练模型
# Reuse the whole pretrained ResNet-34 (including its 1000-way output layer)
# as a frozen feature extractor, and train only a small new head on top.
def get_net(devices):
    """Build the fine-tuning model on ``devices[0]``.

    Structure: ``features`` (pretrained ResNet-34, frozen) followed by
    ``output_new`` (Linear 1000->256, ReLU, Linear 256->120).
    """
    model = nn.Sequential()
    model.features = torchvision.models.resnet34(pretrained=True)
    # New head mapping the backbone's 1000 outputs to the 120 classes.
    model.output_new = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120),
    )
    model = model.to(devices[0])
    # Freeze the backbone: its parameters no longer require gradients.
    # NOTE(review): in train mode, BatchNorm layers inside it still update
    # their running statistics.
    model.features.requires_grad_(False)
    return model
计算损失
# Per-example losses (reduction='none'); callers reduce them.
loss = nn.CrossEntropyLoss(reduction='none')


def evaluate_loss(data_iter, net, devices):
    """Return the mean per-example ``loss`` of ``net`` over ``data_iter``.

    The model is evaluated in eval mode under ``torch.no_grad()``. Its
    original train/eval mode is restored afterwards.

    Args:
        data_iter: iterable of ``(features, labels)`` batches.
        net: model (may be wrapped in ``nn.DataParallel``).
        devices: list of devices; inputs are moved to ``devices[0]``.

    Returns:
        0-d tensor on ``devices[0]`` holding the average loss, or NaN if
        ``data_iter`` yields no examples.
    """
    was_training = net.training
    # Eval mode: BatchNorm uses running statistics and does not update them
    # with validation data.
    net.eval()
    l_sum = torch.zeros((), device=devices[0])
    n = 0
    try:
        # No autograd graph is needed just to measure the loss.
        with torch.no_grad():
            for features, labels in data_iter:
                features = features.to(devices[0])
                labels = labels.to(devices[0])
                l_sum += loss(net(features), labels).sum()
                n += labels.numel()
    finally:
        net.train(was_training)
    if n == 0:
        return torch.tensor(float('nan'), device=devices[0])
    return l_sum / n
训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,
          lr_decay):
    """Fine-tune ``net`` with SGD + momentum and a step learning-rate schedule.

    Only parameters with ``requires_grad=True`` are optimized, so anything
    frozen by ``get_net`` stays fixed.

    Args:
        net: model to train. It is wrapped in ``nn.DataParallel`` over
            ``devices`` inside this function.
        train_iter: iterable of ``(features, labels)`` training batches.
        valid_iter: validation batches, or ``None`` to skip validation.
        num_epochs: number of passes over ``train_iter``.
        lr: initial learning rate.
        wd: weight decay passed to SGD.
        devices: list of devices. Inputs and the wrapped model use
            ``devices[0]``.
        lr_period: the learning rate is multiplied by ``lr_decay`` every
            ``lr_period`` epochs (the scheduler is stepped once per epoch).
        lr_decay: multiplicative learning-rate decay factor.

    Uses the module-level ``loss`` and ``evaluate_loss``. Plots progress and
    prints the final train (and validation) loss.
    """
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Give the optimizer only the parameters that still require gradients;
    # frozen parameters are not updated.
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad),
        lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    # Plot about five points per epoch. max(1, ...) avoids a modulo-by-zero
    # when an epoch has fewer than five batches.
    plot_every = max(1, num_batches // 5)
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    for epoch in range(num_epochs):
        net.train()
        # Running sums: total loss, number of examples.
        metric = d2l.Accumulator(2)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features, labels = features.to(devices[0]), labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l, labels.shape[0])
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[1], None))
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None:
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach()))
        # Once per epoch: decays the lr every lr_period epochs.
        scheduler.step()
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
训练和验证模型
# Hyper-parameters. The learning rate is multiplied by lr_decay every
# lr_period epochs (see train).
devices = d2l.try_all_gpus()
num_epochs, lr, wd = 10, 1e-4, 1e-4
lr_period, lr_decay = 2, 0.9
net = get_net(devices)
# Fine-tune on 'train' and monitor the loss on the held-out 'valid' split.
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices,
      lr_period, lr_decay)
对测试集分类
# Retrain a fresh model on train + valid, then predict class probabilities.
net = get_net(devices)
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,
      lr_decay)

preds = []
# Inference: eval mode so BatchNorm uses running statistics, and no autograd.
net.eval()
with torch.no_grad():
    for data, _ in test_iter:
        # Softmax over the class dimension (dim=1), so each row is a
        # probability distribution over the 120 classes. dim=0 wrongly
        # normalized across the images in the batch.
        output = torch.nn.functional.softmax(net(data.to(devices[0])), dim=1)
        preds.extend(output.cpu().detach().numpy())
# Assumes ImageFolder enumerates the test files in sorted file-name order, so
# these sorted names line up with preds -- TODO confirm.
ids = sorted(
    os.listdir(os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')))
with open('submission.csv', 'w') as f:
    f.write('id,' + ','.join(train_valid_ds.classes) + '\n')
    for i, output in zip(ids, preds):
        f.write(
            i.split('.')[0] + ',' + ','.join([str(num) for num in output])
            + '\n')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。