赞
踩
every blog every motto: Just live your life cause we don’t live twice.
c_loss = nn.CrossEntropyLoss()
loss0 = c_loss(d0, labels_v.long())
报错的原因是,在pytorch中做损失函数计算时,标签为(batch,height,width),如果类别为10类,那么其中的值应该 为 0~9,即:
0<= value<=C-1,其中C为通道数,或类别数
我的类别为10类,其中的值为1~10,所以只需要减1即可,如下图所示。
c_loss = nn.CrossEntropyLoss()
labels_v = labels_v-1
loss0 = c_loss(d0, labels_v.long())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。