当前位置:   article > 正文

语义分割网络的one-hot编码(pytorch)_将3d语义分割图进行one-hot编码

将3d语义分割图进行one-hot编码
import torch
import numpy as np

gt = np.random.randint(0,5, size=[15,15])  #先生成一个15*15的label,值在5以内,意思是5类分割任务
gt = torch.LongTensor(gt)



def get_one_hot(label, N):
    size = list(label.size())
    label = label.view(-1)   # reshape 为向量
    ones = torch.sparse.torch.eye(N)
    ones = ones.index_select(0, label)   # 用上面的办法转为换one hot
    size.append(N)  # 把类别输目添到size的尾后,准备reshape回原来的尺寸
    return ones.view(*size)


gt_one_hot = get_one_hot(gt, 5)
print(gt_one_hot)
print(gt_one_hot.shape)

print(gt_one_hot.argmax(-1) == gt)  # 判断one hot 转换方式是否正确,全是1就是正确的
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/922719
推荐阅读
相关标签
  

闽ICP备14008679号