当前位置:   article > 正文

利用samples数据集跑通STANet

利用samples数据集跑通STANet

训练

命令

python ./train.py --gpu_ids -1 --num_threads 0 --save_epoch_freq 1 --dataroot ./samples --val_dataroot ./samples --name samples --lr 0.001 --model CDF0 --SA_mode PAM --batch_size 2 --load_size 256 --crop_size 256 --preprocess rotate_and_crop
  • 1

注意–gpu_ids -1 的意思是利用CPU
如果遇见下面的错误,请参考here

IndexError: boolean index did not match indexed array along dimension 0; dimension is 4194304 but corresponding boolean dimension is 65536

  • 1
  • 2

学习率收敛到0时,即停止
在这里插入图片描述

验证

测试

修改val.py,改成测试代码,我在这里重新建一个,命名为test_run.py

from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import numpy as np
from util.util import mkdir
from util.util import save_images

def make_test_opt(opt):

    # hard-code some parameters for test
    opt.num_threads = 0   # test code only supports num_threads = 1
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.no_flip2 = True    # no flip; comment this line if results on flipped images are needed.

    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    opt.phase = 'test'
    opt.preprocess = 'none1'
    opt.isTrain = False
    opt.aspect_ratio = 1
    opt.eval = True

    return opt

def prdeict(opt):

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    # save_path = os.path.join(opt.checkpoints_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # ./checkpoints/samples/test_epoch
    save_path = 'samples/output'
    mkdir(save_path)

    model.eval()
    for i, data in enumerate(dataset):
        # data的结构
        '''{'A': A, 'A_paths': A_path,
            'B': B, 'B_paths': B_path}'''
        model.set_input(data)  # unpack data from data loader
        pred = model.test(val=False)  # run inference return pred
        # print('pred=='+str(type(pred)))
        img_path = model.get_image_paths()     # get image paths
        # print('img_path=='+str(img_path))
        # print('save_path==' + str(save_path))
        save_images(pred, save_path, img_path)


if __name__ == '__main__':
    opt = TestOptions().parse()   # get training options
    opt = make_test_opt(opt)
    opt.phase = 'test' # 测试阶段
    opt.dataroot = 'samples'
    opt.dataset_mode = 'changedetection'
    opt.n_class = 2
    opt.SA_mode = 'PAM'
    opt.arch = 'mynet3' # 特征提取器架构
    opt.model = 'CDF0' # CDF0|CDFA
    opt.name = 'samples' # 保存路径的名字
    opt.results_dir = './results/'
    opt.epoch = '189_F1_1_0.95515' # 最佳模型的名字,去掉了_net_F.pth
    opt.num_test = np.inf

    opt.gpu_ids = False # 因为训练用的是CPU,然后,这里不能写-1
    # opt.istest = True # 判断是否为测试过程,适用于没有标签的情况;修改一点点代码,就不再需要这一个标签了,加上也无妨。
    prdeict(opt)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67

如果出现报错

pred是nonetype,空类型
  • 1

需要修改CFD0的test函数,因为在val为False时,没有给出返回值

    def test(self, val=False):
        """Forward function used in test time. 在测试时间中使用的正向功能。
        这个函数将<forward>函数封装在no_grad()中,这样我们就不会为反向运算保存中间步骤
        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
            if val:  # score
                from util.metrics import RunningMetrics
                metrics = RunningMetrics(self.n_class) # n_class是2,标签分为两类
                pred = self.pred_L.long()
                # detach()返回一个新的从当前图中分离的Variable
                metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
                scores = metrics.get_cm()
                return scores
            #增加下面两行
            else: #自己的加的
                return self.pred_L.long()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

结果
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/345219
推荐阅读
相关标签
  

闽ICP备14008679号