p = {'trainBatch': 6, 'nAveGrad': 1, 'lr': 1e-07, 'wd': 0.0005, 'momentum': 0.9, 'epoch_size': 10, 'optimizer': 'SGD()'} — the value stored under the last key, 'optimizer', is actually a very long string, so it is not written out in full here. The dictionary has 7 entries.
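A rough sketch of how such a dictionary is typically filled in. The placeholder module below stands in for the real net defined later, and the comments on what each key controls are based on common usage of these names, not confirmed by the snippet itself:

import torch.nn as nn
import torch.optim as optim

net = nn.Conv2d(3, 16, kernel_size=3)   # placeholder module; the real net is built later in the script

p = {'trainBatch': 6,     # images per training batch
     'nAveGrad': 1,       # batches to accumulate gradients over before an optimizer step (assumed meaning)
     'lr': 1e-07,         # learning rate
     'wd': 0.0005,        # weight decay
     'momentum': 0.9,     # SGD momentum
     'epoch_size': 10}    # epochs between learning-rate adjustments (assumed meaning)

optimizer = optim.SGD(net.parameters(), lr=p['lr'],
                      momentum=p['momentum'], weight_decay=p['wd'])
p['optimizer'] = str(optimizer)   # the long string mentioned above
print(len(p))                     # 7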
The net and criterion used here will be explained later.
If resume_epoch == 0, training starts from scratch; otherwise the weights are initialized from a previously trained model via net.load_state_dict, a function defined in the torch.nn.Module class.
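A minimal sketch of this resume logic. resume_epoch and net come from the script being discussed; the save_dir / modelName checkpoint naming below is a hypothetical convention, not taken from the original code:

import os
import torch

if resume_epoch == 0:
    print("Training from scratch...")
else:
    # Hypothetical checkpoint path; the real script defines its own save_dir and modelName.
    ckpt = os.path.join(save_dir, 'models',
                        modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')
    print("Initializing weights from: {}".format(ckpt))
    net.load_state_dict(
        torch.load(ckpt, map_location=lambda storage, loc: storage))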
def load_state_dict(self, state_dict, strict=True):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    Arguments:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        module._load_from_state_dict(
            state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
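To make the strict flag concrete, here is a small self-contained example; the Linear module and tensors are made up purely for illustration:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
state = net.state_dict()                    # keys: 'weight', 'bias'
net.load_state_dict(state)                  # exact key match -> loads fine

partial = {'weight': torch.zeros(2, 4)}     # 'bias' is missing
# net.load_state_dict(partial)              # strict=True (default) would raise an error about missing keys
net.load_state_dict(partial, strict=False)  # strict=False loads 'weight' and skips the missing 'bias'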
The torch.load function called inside is defined as follows. Its map_location argument can take three forms: a function, a string, or a dict.
def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with :func:`torch.save` from a file.

    :meth:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the `map_location` argument.

    If `map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to map_location. The builtin location tags are `'cpu'` for
    CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.
    `map_location` should return either None or a storage. If `map_location` returns
    a storage, it will be used as the final deserialized object, already moved to
    the right device. Otherwise, :math:`torch.load` will fall back to the default
    behavior, as if `map_location` wasn't specified.

    If `map_location` is a string, it should be a device tag, where all tensors
    should be loaded.

    Otherwise, if `map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using `register_package`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: a function, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location='cpu')
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt') as f:
                buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
    """
Next the GPU is set up:
torch.cuda.set_device(device=0) tells PyTorch to use GPU 0 as the current CUDA device.
net.cuda() then moves the model onto GPU 0.
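Put together, and with an availability check added here that is not in the original script, the device setup looks roughly like this:

import torch
import torch.nn as nn

net = nn.Conv2d(3, 16, kernel_size=3)   # placeholder; the real net is built earlier

gpu_id = 0
if torch.cuda.is_available():
    torch.cuda.set_device(device=gpu_id)  # make GPU 0 the current CUDA device
    net.cuda()                            # parameters and buffers of net now live on GPU 0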
The line writer = SummaryWriter(log_dir=log_dir) will be explained later.
num_img_tr = len(trainloader)   # 1764
num_img_ts = len(testloader)    # 242
These are the numbers of batches per epoch, not the numbers of images.
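For reference, len(dataloader) counts batches, i.e. ceil(len(dataset) / batch_size) when drop_last=False (the default). A tiny self-contained check:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3))              # 100 samples
loader = DataLoader(dataset, batch_size=6, shuffle=True)
print(len(dataset))   # 100 samples
print(len(loader))    # 17 batches: ceil(100 / 6)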