Note: the "root directory" below refers to the root of the albert-pytorch project, i.e. the GitHub repo of the same name.
Call chain followed when the model weights file is loaded:
/run_classifier.py(387) AlbertForSequenceClassification.from_pretrained()->
/model/modeling_utils.py(191) from_pretrained() ->
/model/modeling_utils.py(363) load() ->
/model/modeling_utils.py(347) load() -> # a recursive function; the "prefix" argument changes on each recursive call and controls which model parameters get loaded (a simplified sketch follows this call chain)
/model/modeling_utils.py(347) module._load_from_state_dict() ->
{$TORCH_HOME}/nn/modules/module.py(703) _load_from_state_dict()
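To make the recursion concrete, here is a runnable toy sketch (not the repo's code: a two-level nn.Sequential stands in for the ALBERT model, and the lists passed to _load_from_state_dict mirror the ones used inside load_state_dict()):

import torch
import torch.nn as nn

# Toy model: '0.' will be the prefix for the Embedding, '1.' for the Linear.
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
state_dict = model.state_dict()
missing_keys, unexpected_keys, error_msgs = [], [], []

def load(module, prefix=''):
    # each module copies the entries of state_dict whose keys start with `prefix` ...
    module._load_from_state_dict(state_dict, prefix, {}, True,
                                 missing_keys, unexpected_keys, error_msgs)
    print('visited prefix:', repr(prefix))
    # ... then recurses into its children, extending the prefix with the child's name
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')

load(model)   # prints '', '0.', '1.'; e.g. key '0.weight' is matched inside the Embedding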
The crux is inside this function, _load_from_state_dict(), specifically the for-loop at lines 742~769. When a model parameter is loaded successfully, line 762, param.copy_(input_param), gets executed (the for-loop is reproduced below):
742        for name, param in local_state.items():
743            key = prefix + name
744            if key in state_dict:
745                input_param = state_dict[key]
746
747                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
748                if len(param.shape) == 0 and len(input_param.shape) == 1:
749                    input_param = input_param[0]
750
751                if input_param.shape != param.shape:
752                    # local shape should match the one in checkpoint
753                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
754                                      'the shape in current model is {}.'
755                                      .format(key, input_param.shape, param.shape))
756                    continue
757
758                if isinstance(input_param, Parameter):
759                    # backwards compatibility for serialized parameters
760                    input_param = input_param.data
761                try:
762                    param.copy_(input_param)
763                except Exception:
764                    error_msgs.append('While copying the parameter named "{}", '
765                                      'whose dimensions in the model are {} and '
766                                      'whose dimensions in the checkpoint are {}.'
767                                      .format(key, param.size(), input_param.size()))
768            elif strict:
769                missing_keys.append(key)
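To see how these branches surface to the caller, here is a small standalone illustration (a plain nn.Linear, unrelated to ALBERT): a wrong-shaped entry goes through the size-mismatch branch at lines 751~756 and is reported via error_msgs, while an absent key goes through the elif strict: branch at lines 768~769:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

# wrong shape for 'weight' -> lines 751~756, collected in error_msgs
try:
    layer.load_state_dict({'weight': torch.zeros(3, 4), 'bias': torch.zeros(2)})
except RuntimeError as e:
    print(e)   # ... size mismatch for weight: copying a param with shape torch.Size([3, 4]) ...

# 'weight' missing altogether -> lines 768~769, collected in missing_keys
try:
    layer.load_state_dict({'bias': torch.zeros(2)})
except RuntimeError as e:
    print(e)   # ... Missing key(s) in state_dict: "weight" ...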
Here param is a torch.Tensor, so let's read the documentation of torch.Tensor.copy_():
Copies the elements from src into self tensor and returns self. The src tensor must be broadcastable with the self tensor. It may be of a different data type or reside on a different device.
In plain terms: src = input_param is copied into self (here self is the param tensor held by the current nn.Module), and self is returned.
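A tiny standalone example of these semantics (unrelated to the checkpoint):

import torch

dst = torch.zeros(2, 3)                                    # plays the role of param (self)
src = torch.arange(6, dtype=torch.float64).reshape(2, 3)   # plays the role of input_param

out = dst.copy_(src)      # in-place copy; src's float64 values are converted to dst's float32
print(dst)                # dst now holds 0..5
print(out is dst)         # True -- copy_() returns self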
Take the parameter bert.embeddings.word_embeddings.weight as an example. The "self" this param lives in is Embedding(21128, 128, padding_idx=0), i.e. the Embedding layer. From class torch.nn.Embedding (source code attached below) we can see that the layer has attributes such as num_embeddings and embedding_dim, here 21128 and 128 respectively (the former is the vocabulary size, the latter is ALBERT's embedding size).
class Embedding(Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                 sparse=False, _weight=None):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)
        self.sparse = sparse
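As a quick sanity check (not from the repo), constructing the layer with the same arguments gives a weight of exactly the shape the checkpoint entry will be copied into:

import torch.nn as nn

emb = nn.Embedding(21128, 128, padding_idx=0)   # 21128 = vocab size, 128 = ALBERT embedding size
print(emb)                                      # Embedding(21128, 128, padding_idx=0)
print(emb.weight.shape)                         # torch.Size([21128, 128]), matching
                                                # bert.embeddings.word_embeddings.weight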
Note lines 18~20 of the source above, i.e. the if _weight is None: branch. Before param.copy_(input_param) has run, the Embedding layer's parameters have not yet been loaded from the checkpoint: the tensor is created with the specified dimensions, and self.reset_parameters() on line 20 fills it with random numbers drawn from the standard normal distribution N(0, 1). Let's print this self.weight first:
ipdb> self.weight
Parameter containing:
tensor([[-0.0072, -0.0040, 0.0490, ..., -0.0219, 0.0050, -0.0293],
[ 0.0405, 0.0166, -0.0039, ..., -0.0099, -0.0004, -0.0137],
[ 0.0111, -0.0048, 0.0283, ..., 0.0047, -0.0072, 0.0130],
...,
[ 0.0209, -0.0084, -0.0283, ..., 0.0367, 0.0080, -0.0220],
[ 0.0584, 0.0286, 0.0028, ..., -0.0016, 0.0436, 0.0071],
[ 0.0238, -0.0204, 0.0172, ..., -0.0435, -0.0267, 0.0099]],
requires_grad=True)
After param.copy_(input_param) has executed, check whether self.weight has been overwritten with the contents of bert.embeddings.word_embeddings.weight:
ipdb> self.weight
Parameter containing:
tensor([[ 0.0722, 0.0224, 0.1045, ..., 0.0800, 0.0776, -0.0483],
[ 0.0779, 0.0606, 0.0891, ..., 0.0628, 0.0831, -0.0924],
[ 0.0891, 0.0782, 0.0731, ..., 0.0609, 0.1201, -0.0561],
...,
[ 0.0159, 0.0438, 0.1095, ..., 0.0802, 0.0773, -0.0790],
[ 0.0664, 0.0513, 0.1075, ..., 0.0682, 0.0776, -0.0842],
[ 0.0135, 0.0239, 0.1113, ..., 0.0646, 0.0756, -0.0632]],
requires_grad=True)
This is exactly the value of bert.embeddings.word_embeddings.weight (i.e. input_param):
ipdb> (input_param == self.weight).numpy().all()
True
This confirms that param.copy_(input_param) on line 762 is what writes input_param into self.weight.
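The same fact can be reproduced without ipdb on any module: after load_state_dict(), every parameter with a matching key holds exactly the checkpoint's values, because each one went through param.copy_(input_param). A minimal self-contained sketch (a plain nn.Embedding stands in for the real model and pytorch_model.bin):

import torch
import torch.nn as nn

src_model = nn.Embedding(21128, 128, padding_idx=0)
checkpoint = src_model.state_dict()                     # stands in for the checkpoint file

dst_model = nn.Embedding(21128, 128, padding_idx=0)     # freshly (randomly) initialized
dst_model.load_state_dict(checkpoint)

print(torch.equal(dst_model.weight.data, checkpoint['weight']))   # True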