赞
踩
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
'''
定义一个转化函数,目的就是可以把我们常见的图片文件(jpg,png等)通过torchvision.transforms.ToTensor()函数
转化成tensor格式,然后再通过torchvision.transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
函数进行图片归一化,即将tensor中的值都限定在-1到1之间。
'''
transfroms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()],
torchvision.transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]))
'''
读取目录下的图片文件,你的文件夹格式应该为如下情况:
directory/
├── class_x
│ ├── xxx.jpg
│ ├── xxy.jpg
│ └── ...
│ └── xxz.jpg
└── class_y
├── 123.jpg
├── nsdf3.jpg
└── ...
└── asd932_.jpg
即,directory下面有各种各样的类别文件夹,然后在这个类别文件夹下才是你自己的图片文件,之所以需要这样
是因为,它这个torchvision.datasets.DatasetFolder函数能够根据你的class_x的名字来自动定义标签,
如上,class_x文件夹下的图片的标签就是0,class_y下的图片标签就是1,以此类推。
'''
imagenet_dataset = torchvision.datasets.ImageFolder('path/to/directory/', transform=transfroms )
'''
最后就是构造一个迭代器了,batch_size就是接下来要批量处理的数量,shffle标志是否给imagenet_data中的数据打乱,顺序随机一下。
'''
data_loader = torch.utils.data.DataLoader(imagenet_dataset ,
batch_size=4,
shuffle=True)
'''
最后就是可以处理数据了,我们先假设我们已经构造好了模型为model, 损失函数为criterion,优化函数为optimizer
'''
model = ’自己的网络结构‘
optimizer = optim.Adam(model.parameters(), lr=0.0001) # 优化方法,学习率为lr的值
criterion = nn.MSELoss() # 损失函数,即计算模型的预测值和真实标签label的差异大小的函数
epoch = 200 # 将整个数据集训练200次
for i in range(epoch):
for data in data_loader:
imgs, labels = data #提出图片和该图片的标签
# ======前向传播=======
outputs = model(img)
loss = criterion(outputs, labels)
# ======反向传播=======
optimizer.zero_grad()
loss.backward()
optimizer.step()
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
'''
定义一个转化函数,目的就是可以把我们常见的图片文件(jpg,png等)通过torchvision.transforms.ToTensor()函数
转化成tensor格式,然后再通过torchvision.transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
函数进行图片归一化,即将tensor中的值都限定在-1到1之间。
'''
transfroms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()],
torchvision.transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]))
'''
读取目录下的图片文件,你的文件夹格式应该为如下情况:
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
即,directory下面有各种各样的类别文件夹,然后在这个类别文件夹下才是你自己的图片文件,之所以需要这样
是因为,它这个torchvision.datasets.DatasetFolder函数能够根据你的class_x的名字来自动定义标签,
如上,class_x文件夹下的图片的标签就是0,class_y下的图片标签就是1,以此类推。
'''
my_dataset = torchvision.datasets.DatasetFolder('path/to/directory/', transform=transfroms )
'''
最后就是构造一个迭代器了,batch_size就是接下来要批量处理的数量,shffle标志是否给imagenet_data中的数据打乱,顺序随机一下。
'''
data_loader = torch.utils.data.DataLoader(my_dataset ,
batch_size=4,
shuffle=True)
'''
之后就是调用模型,遍历这些数据集进行训练操作了,方法见方法一
'''
因为我们使用官方的类来读取数据集多多少少不能够满足我们的个性化需求,比如我们的数据集一定要是这样的格式:
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
这样的数据集格式就存在了一个问题,这只适用于有标签的数据集,然后不同的类别的图片还需要放到不同的文件夹下。那么我们如果是要进行无监督数据学习呢?也就是没有标签,这样就需要我们自己定义一个datasets类,来读取我们的个性化数据集。
比如我的数据是无标签的,数据格式如下:
directory/
├── xxx.png
├── xxy.png
└── ...
└── xxz.png
from torch.utils.data import Dataset
from PIL import Image
import os
'''
该类继承了Dataset,必须重定义__getitem__()和__len__()函数
'''
class Mydata_sets(Dataset):
# 初始化函数
def __init__(self, path, transform=None):
super(Mydata_sets, self).__init__()
self.root_dir = path
self.img_path = os.listdir(self.root_dir)
self.transform = transform
# 获取数据的方法
def __getitem__(self, index):
img_name = self.img_path[index]
img = Image.open(os.path.join(self.root_dir, img_name))
if self.transform is not None:
img = self.transform(img)
return img
# 可以理解为,告诉Dataset类,我们的数据有多少个
def __len__(self):
return len(self.img_path)
my_datasets = Mydata_sets(pic_path, transform=transform)
imgLoader = torch.utils.data.DataLoader(my_datasets, batch_size=128, shuffle=False, num_workers=1)
有可能我们的数据类别标签与元数据分别存放在两个txt里。
参文原文链接:https://blog.csdn.net/WMT834161117/article/details/118029358
root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt"
root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt"
class Mydata_sets(Dataset):
def __init__(self, txt, transform=None, target_transform=None):
super(Mydata_sets, self).__init__()
self.txt = txt
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1]))) # imgs中包含有图像路径和标签
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(os.path.join(self.txt[:-4], fn))
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
# 定义类
class Mydata_sets(Dataset):
def __init__(self, path, transform=None):
super(Mydata_sets, self).__init__()
self.root_dir = path
self.img_names = os.listdir(self.root_dir)
self.transform = transform
def __getitem__(self, index):
img_name = self.img_names[index]
img = Image.open(os.path.join(self.root_dir, img_name))
id_name = torch.tensor(int(self.img_names[index][4:-4])) #pic_xx.png
if self.transform is not None:
img = self.transform(img)
return img, id_name
def __len__(self):
return len(self.img_names)
# 处理示例
transform = transforms.Compose([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
ResNet50 = torch.load("./model_files/checkpoints/classify_model/ResNet50.pt", map_location=device)
ResNet50 = ResNet50.to(device)
ResNet50.eval()
features, labels, ids = [], [], [] # features:提取的2048维图片特征,labels:模型预测的标签,ids图片文件编号
feature_model = copy.deepcopy(ResNet50)
feature_model.fc = nn.Identity() # 相当于取消fc层, 这样
label_model = copy.deepcopy(ResNet50)
# 图片路径
pic_path = "./static/data/pic/random_50k"
img_datasets = Mydata_sets(pic_path, transform=transform)
imgLoader = torch.utils.data.DataLoader(img_datasets, batch_size=128, shuffle=False, num_workers=4) # 指定读取配置信息
with torch.no_grad():
for x, y in tqdm(imgLoader):
x = x.to(device)
ids.append(y) # N
feature = feature_model(x) # N, 2048
features.append(feature)
ten_D = label_model(x)
label = torch.argmax(ten_D, dim=1)
labels.append(label)
features = torch.cat(features, dim=0).squeeze().cpu().numpy() # (n, 2048)
labels = torch.cat(labels, dim=0).cpu().numpy() # n
ids = torch.cat(ids, dim=0).cpu().numpy()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。