赞
踩
之前一直不太搞明白浅拷贝和赋值、深拷贝到底有什么区别,直到被pytorch的model.state_dict()给坑了
今天在和实验室同学讨论联邦学习框架代码的时候,终于明白了他们之间的区别,这里做个记录。
(1)直接赋值:给变量取个别名,原来叫张三,现在我给他取个小名,叫小张
(2)浅拷贝(shadow copy):拷贝最外层的数值和指针,不拷贝更深层次的对象,即只拷贝了父对象
model.state_dict()
也是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。具体可以看下文的示例题外话:浅拷贝为什么叫“浅”,因为他只拷贝最外层的东西,不会去拷贝最外层“指针”所指向的内层的东西,所以浅。而深拷贝则会拷贝全部层的东西,所以深
(3)深拷贝(deepcopy):拷贝数值、指针和指针指向的深层次内存空间,拷贝了父对象及其子对象。
model.load_state_dict(xxx)
是深拷贝这一小节主要来自:一个工作三年的同事,居然还搞不清深拷贝、浅拷贝…
2021年10月24日 更新:下面这个图其实是以Java语言而言的,我一开始以为Python字符串和int数值应该也是直接赋值的,后来经过验证,发现python中的字符串其实是引用(地址),所以若a=“hello”,则b=a是把"hello"的地址赋值给b。另外-5到256这个范围内的整数是公用一块内存空间的,具体请看我的博客:Python中容易被忽视的知识点:字符串是传引用以及整数-5到256共享内存空间
浅拷贝
深拷贝
深拷贝相较于上面所示的浅拷贝,除了值类型字段会复制一份,引用类型字段所指向的对象,会在内存中也创建一个副本,就像这个样子:
pytorch在获取模型参数和加载模型参数时是有坑点的,而且这个bug一般不太容易发现,因为他不会报错,有时你很难通过实验结果注意到这个问题,我自己写框架时也是被坑过。
model.state_dict()
实际上是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。model.load_state_dict(xxx)
是深拷贝用代码验证以上观点,可以结合上文的两张示意图来理解下面代码
import torch
import copy
m1 = torch.nn.Linear(in_features=5, out_features=1, bias=True)
m2 = torch.nn.Linear(in_features=5, out_features=1, bias=True)
# m1是引用指向某块内存空间
# 浅拷贝相当于拷贝一个引用,所以他们“引用”变量的id是不一样的,指向的内存空间是一样的
ck = copy.copy(m1)
print(id(m1) == id(ck)) # False
print(m1.weight)
# Parameter containing:
# tensor([[ 0.0171, 0.4382, -0.4297, 0.4098, -0.3954]], requires_grad=True)
# state_dict is shadow copy
p = m1.state_dict()
print(id(m1.state_dict()) == id(p)) # False
# 通过引用p去修改内存空间
p['weight'][0][0] = 8.8888
# 可以看到m1指向的内存空间也被修改了
print(m1.state_dict())
# OrderedDict([('weight', tensor([[ 8.8888, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
# deepcopy
m2.load_state_dict(p)
m2.weight[0][0] = 2.0
print(p)
# OrderedDict([('weight', tensor([[ 8.8888, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
print(m2.state_dict())
# OrderedDict([('weight', tensor([[ 2.0000, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
在我的联邦学习框架中本地模型参数确实是浅拷贝,但是我们没有去修改这个local_params,我们只是把不同客户端的local_params加权平均去更新global_params而已,所以不用deepcopy也没事
但如果想保存最优模型的参数,则必须要用deepcopy
best_state changes with the model during training in pytorch 这位提问者想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数,下面是他的错误代码:
def train():
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state
import copy
a = [1, 2, 3, 4, ['a', 'b']] # 原始对象
b = a # 赋值,传对象的引用
c = copy.copy(a) # 对象拷贝,浅拷贝
d = copy.deepcopy(a) # 对象拷贝,深拷贝
a.append(5) # 修改对象a
a[4].append('c') # a[4]是指针,修改对象a中的['a', 'b']数组对象
print('a = ', a)
print('b = ', b)
print('c = ', c) # 浅拷贝,只会拷贝最外层的数值或指针
print('d = ', d)
a = [1, 2, 3, 4, ['a', 'b', 'c'], 5]
b = [1, 2, 3, 4, ['a', 'b', 'c'], 5]
c = [1, 2, 3, 4, ['a', 'b', 'c']]
d = [1, 2, 3, 4, ['a', 'b']]
现在你看下面这段代码的输出结果应该就不奇怪了吧
import copy
A = [1, 2, 3]
print(A) # [1, 2, 3]
B = copy.copy(A) # 浅拷贝(最外层"值"会拷贝,"引用"会拷贝)
B.append(5)
print(A) # [1, 2, 3]
print(B) # [1, 2, 3, 5]
深拷贝示例:
# Python code to demonstrate copy operations
# importing "copy" for copy operations
import copy
# initializing list 1
li1 = [1, 2, [3, 5], 4]
# using deepcopy to deep copy
li2 = copy.deepcopy(li1)
# original elements of list
print("The original elements before deep copying")
for i in range(0, len(li1)):
print(li1[i], end=" ")
print("\r")
# adding and element to new list
li2[2][0] = 7
# Change is reflected in l2
print("The new list of elements after deep copying ")
for i in range(0, len(li1)):
print(li2[i], end=" ")
print("\r")
The original elements before deep copying
1 2 [3, 5] 4
The new list of elements after deep copying
1 2 [7, 5] 4
The original elements after deep copying
1 2 [3, 5] 4
浅拷贝示例:
# Python code to demonstrate copy operations
# importing "copy" for copy operations
import copy
# initializing list 1
li1 = [1, 2, [3,5], 4]
# using copy to shallow copy
li2 = copy.copy(li1)
# original elements of list
print ("The original elements before shallow copying")
for i in range(0,len(li1)):
print (li1[i],end=" ")
print("\r")
# adding and element to new list
li2[2][0] = 7
# checking if change is reflected
print ("The original elements after shallow copying")
for i in range(0,len( li1)):
print (li1[i],end=" ")
The original elements before shallow copying
1 2 [3, 5] 4
The original elements after shallow copying
1 2 [7, 5] 4
注意:上面用了li2[2][0] = 7
,相当于是在修改引用的内存空间;如果是li2[1] = 7
,那么l1[1]
不会改变
ii. copy in Python (Deep Copy and Shallow Copy) (geeksforgeeks的文章还是挺清楚的)
v. 一个工作三年的同事,居然还搞不清深拷贝、浅拷贝… (图解挺不错的)
vi. best_state changes with the model during training in pytorch (这位老哥想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数)
vii. Python中的赋值(复制)、浅拷贝与深拷贝 (这篇文章关于可变对象和不可对象的拷贝的id是否会改变进行了讨论)
✨ 原创不易,还希望各位大佬支持一下 \textcolor{blue}{原创不易,还希望各位大佬支持一下} 原创不易,还希望各位大佬支持一下
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/160159
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。