
[AI Paper Reproduction] PanNet: A Deep Network Architecture for Pan-Sharpening (TensorFlow 1.x)


"PanNet: A Deep Network Architecture for Pan-Sharpening", ICCV 2017

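Preface

Deep learning frameworks

  • MatConvNet
  • Caffe
  • TensorFlow 1.x
  • TensorFlow 2.x
  • PyTorch

Tensor: a tensor, i.e. the data. Flow: flowing. TensorFlow: data flowing through the network.

Configuring the GPU from Python is fairly easy. Usually you first test the code on the CPU and then run it on the GPU.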
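Here is a minimal sketch of the CPU-first workflow. It relies on the standard CUDA behaviour that `CUDA_VISIBLE_DEVICES` is read when TensorFlow initialises the GPU runtime. Setting the variable to `"-1"` before importing TensorFlow hides every GPU, so the same script runs on the CPU. Switching it to `"0"` (the value train.py in the appendix uses) selects the first GPU.

```python
import os

# Must be set BEFORE `import tensorflow`: the CUDA runtime reads it at initialisation.
# "-1" is not a valid device id, so no GPU is visible and TensorFlow falls back to the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Once the code works on the CPU, change the value to "0" to use the first GPU.
# import tensorflow as tf  # import only after the variable has been set
```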

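PanNet (a convolutional neural network) <-> pan-sharpening (remote sensing image fusion)

PAN image + low-resolution multispectral (MS) image, fused = high-spatial-resolution multispectral image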

1. Environment setup

Installing packages

pip install tensorflow==1.15.0
pip install opencv-python
pip install h5py
pip install scipy

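TensorFlow 2.0 removed tf.contrib entirely. If the paper code you are reproducing uses tf.contrib, you must install a TensorFlow version below 2.0! Note that tensorflow==1.15.0 is only published for Python 3.7 and older.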
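A small check at the top of a script fails early with a clear message, instead of hitting an `AttributeError` on `tf.contrib` later. `contrib_available` is a hypothetical helper written for this post, not a TensorFlow API. It only assumes that `tf.__version__` is a string such as `"1.15.0"`.

```python
def contrib_available(version: str) -> bool:
    """Return True if a TensorFlow version string belongs to the 1.x series.

    tf.contrib exists only in TensorFlow 1.x; it was removed in 2.0.
    """
    major = int(version.split(".")[0])
    return major < 2


# Usage with TensorFlow installed:
#   import tensorflow as tf
#   assert contrib_available(tf.__version__), "install tensorflow==1.15.0"
print(contrib_available("1.15.0"))  # True
print(contrib_available("2.4.1"))   # False
```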

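2. Close reading of the paper

2.1 Getting the paper and code

The paper can be found at: https://www.paperswithcode.com/

The code is available on GitHub, or can be downloaded from here…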

2.2 Paper walkthrough

[Figure]

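Remote sensing image fusion (pan-sharpening):

A satellite captures a panchromatic image (PAN, 64x64x1) and, at the same time, a multispectral image with low spatial resolution (LRMS, 16x16x8). Fusing the two yields a multispectral image with high spatial resolution (HRMS, 64x64x8). These are the training-patch sizes: PAN has 4 times the resolution of LRMS in each direction.

[Figure]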

Simply upsampling the image makes it blurry.
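The sketch below shows why. It upsamples a random 16x16x8 stand-in LRMS patch by a factor of 4 using nearest-neighbour interpolation. The source does not say which interpolation produced the dataset's `lms` arrays (bicubic is common), so nearest-neighbour is an assumption chosen for simplicity. The point holds either way: upsampling enlarges the image without adding any new detail.

```python
import numpy as np

def upsample_nearest(ms: np.ndarray, scale: int = 4) -> np.ndarray:
    """Nearest-neighbour upsampling of an H x W x C image by an integer factor."""
    # Repeat every pixel `scale` times along height and width; channels are untouched.
    return np.repeat(np.repeat(ms, scale, axis=0), scale, axis=1)

rng = np.random.default_rng(0)
ms = rng.random((16, 16, 8), dtype=np.float32)  # stand-in LRMS patch, 16x16x8
lms = upsample_nearest(ms)                       # 64x64x8, the size of PAN/GT
print(lms.shape)                                 # (64, 64, 8)

# Each 4x4 block is a single constant value: the image grows, but no detail appears.
print(np.all(lms[:4, :4, 0] == ms[0, 0, 0]))     # True
```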

The figure labels 8 bands but actually draws only 4. It is just a schematic, so this does not matter.

Training data: PAN image, LRMS image, GT (ground-truth HRMS) image

Convolutional neural network (CNN):

[Figure]

The figure above is equivalent to:

[Figure]

Residual network:

[Figure]

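The residual holds content such as high-frequency details (edges, textures). In PanNet the network predicts this residual, and the output is the residual plus the upsampled MS image lms.

The residual network (ResNet) was proposed by Kaiming He et al. at CVPR 2016.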
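PanNet does not feed raw images to the network. Like `get_edge` in the appendix, it subtracts a 5x5 box-filtered (low-pass) copy and keeps only the high-frequency part. The sketch below re-implements that high-pass filter in NumPy so it runs without OpenCV. It assumes OpenCV's default border mode (`BORDER_REFLECT_101`, which NumPy's `"reflect"` padding reproduces). It is not guaranteed to be bit-identical to `cv2.boxFilter`.

```python
import numpy as np

def box_filter(img: np.ndarray, k: int = 5) -> np.ndarray:
    """k x k mean filter per channel (approximates cv2.boxFilter(img, -1, (k, k)))."""
    r = k // 2
    pad = [(r, r), (r, r)] + [(0, 0)] * (img.ndim - 2)  # pad only height and width
    p = np.pad(img, pad, mode="reflect")  # "reflect" == OpenCV's BORDER_REFLECT_101
    h, w = img.shape[:2]
    out = np.zeros(img.shape, dtype=np.float64)
    for dy in range(k):          # sum the k*k shifted copies of the image ...
        for dx in range(k):
            out += p[dy:dy + h, dx:dx + w]
    return out / (k * k)         # ... and divide to get the window mean

def get_edge(img: np.ndarray) -> np.ndarray:
    """High-frequency part: the image minus its low-pass (box-filtered) version."""
    return img - box_filter(img)

flat = np.full((64, 64), 0.5)
print(np.allclose(get_edge(flat), 0))  # True: a constant image has no high frequencies
```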

2.3 Code walkthrough

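The data comes in two kinds:

  • Training data
    • train.mat (4-D arrays N x H x W x C, except pan, which is stored as N x H x W)
      • pan: 100x64x64 (100 samples; expanded to 100x64x64x1 in the code)
      • ms: 100x16x16x8 (100 samples)
      • gt: 100x64x64x8 (100 samples)
      • lms (ms upsampled 4x): 100x64x64x8 (100 samples)
    • validation.mat (used for tuning; same fields as train.mat)
  • Test data
    • pan
    • ms
    • lms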

train_data

[Figure]

test_data
[Figure]

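batch_size = 32 means 32 of the 100 samples are picked at random. train.py uses np.random.randint, so the indices are drawn with replacement and can repeat. A batch therefore contains:

  • pan: 32x64x64x1 (32 samples)
  • ms: 32x16x16x8 (32 samples)
  • gt: 32x64x64x8 (32 samples)
  • lms (ms upsampled 4x): 32x64x64x8 (32 samples)

Working in mini-batches keeps each training step cheap and usually lets training converge well.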
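The batch sampling can be tried on stand-in data. The sketch below fills arrays with the train.mat shapes using random numbers (not the real data) and draws a batch the way train.py does. `rng.integers` is NumPy's newer equivalent of `np.random.randint`. The last line shows sampling without replacement, an alternative that the original code does not use.

```python
import numpy as np

rng = np.random.default_rng(0)
N, bs = 100, 32
# Stand-in arrays with the shapes stored in train.mat (random values, not real data).
data = {
    "pan": rng.random((N, 64, 64), dtype=np.float32),
    "ms":  rng.random((N, 16, 16, 8), dtype=np.float32),
    "lms": rng.random((N, 64, 64, 8), dtype=np.float32),
    "gt":  rng.random((N, 64, 64, 8), dtype=np.float32),
}

# As in train.py: bs random indices in [0, N), drawn with replacement.
idx = rng.integers(0, N, size=bs)
batch = {k: v[idx] for k, v in data.items()}
batch["pan"] = batch["pan"][..., np.newaxis]  # N x H x W -> N x H x W x 1

for k, v in batch.items():
    print(k, v.shape)
# pan (32, 64, 64, 1)
# ms (32, 16, 16, 8)
# lms (32, 64, 64, 8)
# gt (32, 64, 64, 8)

# Alternative: without replacement, so each sample appears at most once per batch.
idx_unique = rng.choice(N, size=bs, replace=False)
```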

# num_fm = 32: the layer has 32 convolution kernels, i.e. 32 output feature maps
# stride = 1: the kernel moves one pixel at a time
rs = ly.conv2d(ms, num_outputs = num_fm, kernel_size = 3, stride = 1,
               weights_regularizer = ly.l2_regularizer(weight_decay),
               weights_initializer = ly.variance_scaling_initializer(),
               activation_fn = tf.nn.relu)   # 32 x 64 x 64 x 32

What this means:
[Figure]
With stride = 2 the kernel is applied at every other pixel, so the output becomes 32x32x32.
With stride = 4 the output is (64/4)x(64/4)x32 = 16x16x32.
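`ly.conv2d` (`tf.contrib.layers.conv2d`) uses `padding='SAME'` by default. With SAME padding, the spatial output size is ceil(input / stride), independent of the kernel size. The helper below is hypothetical and exists only to check the arithmetic; it reproduces the numbers above:

```python
import math

def same_output_size(in_size: int, stride: int) -> int:
    """Spatial output size of a 'SAME'-padded convolution: ceil(in_size / stride)."""
    return math.ceil(in_size / stride)

for s in (1, 2, 4):
    n = same_output_size(64, s)
    print(f"stride={s}: {n}x{n}x32")  # 32 = num_fm output channels
# stride=1: 64x64x32
# stride=2: 32x32x32
# stride=4: 16x16x32
```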

Residual network

for i in range(num_res):   # ResNet
    rs1 = ly.conv2d(rs, num_outputs = num_fm, kernel_size = 3, stride = 1,
                    weights_regularizer = ly.l2_regularizer(weight_decay),
                    weights_initializer = ly.variance_scaling_initializer(),
                    activation_fn = tf.nn.relu) # 32 x 64 x 64 x 32; first conv of the block + ReLU (non-linearity)

    rs1 = ly.conv2d(rs1, num_outputs = num_fm, kernel_size = 3, stride = 1,
                    weights_regularizer = ly.l2_regularizer(weight_decay),
                    weights_initializer = ly.variance_scaling_initializer(),
                    activation_fn = None) # 32 x 64 x 64 x 32; second conv of the block, no ReLU

    rs = tf.add(rs,rs1)   # 32 x 64 x 64 x 32; skip connection: add the block input x (compare with the ResNet figure)

Compare with the figure below:
[Figure]

One ResNet block corresponds to two convolutional layers, and there are num_res = 4 blocks in total.

PanNet therefore has 1 + (2x4) + 1 = 10 convolutional layers in total, not counting the transposed convolution that upsamples ms.
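The block structure can be checked outside TensorFlow. The NumPy sketch below implements a 3x3, stride-1, zero-padded convolution (a cross-correlation, as in TensorFlow) and one residual block: conv + ReLU, conv, then add the input. It stacks `num_res = 4` blocks on a 64x64x32 feature map. The weights are random stand-ins, not trained PanNet parameters, and the biases are zero.

```python
import numpy as np

def conv3x3_same(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """3x3 convolution, stride 1, zero 'SAME' padding (no kernel flip, like tf.nn.conv2d).

    x: H x W x Cin, w: 3 x 3 x Cin x Cout, b: Cout.
    """
    h, wd, _ = x.shape
    p = np.pad(x, ((1, 1), (1, 1), (0, 0)))  # zero padding keeps the output at H x W
    out = np.zeros((h, wd, w.shape[3]), dtype=x.dtype)
    for dy in range(3):
        for dx in range(3):
            # Shifted input (H x W x Cin) times this kernel tap (Cin x Cout).
            out += p[dy:dy + h, dx:dx + wd] @ w[dy, dx]
    return out + b

def res_block(x, w1, b1, w2, b2):
    """One residual block: x + conv(relu(conv(x))), with no ReLU after the addition."""
    y = np.maximum(conv3x3_same(x, w1, b1), 0)  # first conv + ReLU
    y = conv3x3_same(y, w2, b2)                 # second conv, no activation
    return x + y                                # skip connection

rng = np.random.default_rng(0)
C = 32                                          # num_fm
x = rng.standard_normal((64, 64, C)).astype(np.float32)
b = np.zeros(C, dtype=np.float32)

def rand_kernel():
    return (0.01 * rng.standard_normal((3, 3, C, C))).astype(np.float32)

rs = x
for _ in range(4):                              # num_res = 4 blocks = 8 conv layers
    rs = res_block(rs, rand_kernel(), b, rand_kernel(), b)
print(rs.shape)                                 # (64, 64, 32)
```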

Once the loss is defined, Adam or SGD is used to solve for the network parameters θ.
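For reference, the NumPy sketch below shows the textbook Adam update applied to every parameter, on a toy MSE problem. `lr` and `beta1` default to train.py's values (0.001, 0.9), and `beta2`/`eps` default to TensorFlow 1.x's `AdamOptimizer` defaults (0.999, 1e-8). TensorFlow folds the bias correction into the step size, which only changes where `eps` enters. The toy run raises `lr` to 0.05 so that it converges within 2000 steps.

```python
import numpy as np

def adam_minimize(grad_fn, theta, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, steps=1000):
    """Minimise a function with Adam, given a function that returns its gradient."""
    m = np.zeros_like(theta)  # running mean of gradients (first moment)
    v = np.zeros_like(theta)  # running mean of squared gradients (second moment)
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias correction for the zero initialisation
        v_hat = v / (1 - beta2 ** t)
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

# Toy problem: minimise the MSE between theta and a fixed target vector.
target = np.array([1.0, -2.0, 0.5])

def mse_grad(theta):
    return 2.0 * (theta - target) / theta.size  # gradient of mean((theta - target) ** 2)

theta = adam_minimize(mse_grad, np.zeros(3), lr=0.05, steps=2000)
print(np.round(theta, 3))  # close to the target [1, -2, 0.5]
```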

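The test data:
[Figure]

Note that pan is now 256x256, while ms is 64x64x8.

The test script can process inputs of any size, not only 16x16x8. Every layer is convolutional, so the only requirement is that the PAN image is 4 times the size of the MS image in each direction.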
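The shapes that test.py builds from the input size can be written out explicitly. `placeholder_shapes` below is a hypothetical helper that mirrors the placeholder shapes in test.py. It uses integer division (`//`) for the MS size and rejects PAN sizes that are not divisible by 4.

```python
def placeholder_shapes(pan_h: int, pan_w: int, bands: int = 8, scale: int = 4) -> dict:
    """Input shapes for a batch of one test image, derived from the PAN size."""
    if pan_h % scale or pan_w % scale:
        raise ValueError(f"PAN height and width must be divisible by {scale}")
    return {
        "pan_hp": (1, pan_h, pan_w, 1),
        "ms_hp": (1, pan_h // scale, pan_w // scale, bands),  # integer division
        "lms": (1, pan_h, pan_w, bands),
    }

print(placeholder_shapes(256, 256))
# {'pan_hp': (1, 256, 256, 1), 'ms_hp': (1, 64, 64, 8), 'lms': (1, 256, 256, 8)}
```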


Appendix

train.py

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
# This is a re-implementation of training code of this paper:
# J. Yang, X. Fu, Y. Hu, Y. Huang, X. Ding, J. Paisley. "PanNet: A deep network architecture for pan-sharpening", ICCV,2017. 
# author: Junfeng Yang
"""

import tensorflow as tf
import numpy as np # numerical arrays
import cv2 # OpenCV
import tensorflow.contrib.layers as ly # convolution layers in TensorFlow 1.x (tf.contrib is gone in 2.x)
import os
import h5py
import scipy.io as sio # read/write MATLAB .mat files
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # '0' selects the first GPU


# get high-frequency (high-pass)
# high-pass filtering: image minus its 5x5 box-filtered (low-pass) version
def get_edge(data):  
    rs = np.zeros_like(data)
    N = data.shape[0]
    for i in range(N):
        if len(data.shape)==3:
            rs[i,:,:] = data[i,:,:] - cv2.boxFilter(data[i,:,:],-1,(5,5)) # data minus its low-frequency part
        else:
            rs[i,:,:,:] = data[i,:,:,:] - cv2.boxFilter(data[i,:,:,:],-1,(5,5))
    return rs


 # get training patches
def get_batch(train_data,bs): 
    
    gt = train_data['gt'][...]    ## ground truth N*H*W*C
    pan = train_data['pan'][...]  #### Pan image N*H*W
    ms_lr = train_data['ms'][...] ### low resolution MS image
    lms   = train_data['lms'][...]   #### MS image interpolation to Pan scale
    
    gt = np.array(gt,dtype = np.float32) / 2047.  ### normalization, WorldView L = 11
    pan = np.array(pan, dtype = np.float32) /2047.
    ms_lr = np.array(ms_lr, dtype = np.float32) / 2047.
    lms  = np.array(lms, dtype = np.float32) /2047.

    
    N = gt.shape[0]
    batch_index = np.random.randint(0,N,size = bs)
    
    gt_batch = gt[batch_index,:,:,:]
    pan_batch = pan[batch_index,:,:]
    ms_lr_batch = ms_lr[batch_index,:,:,:]
    lms_batch  = lms[batch_index,:,:,:]
    
    pan_hp_batch = get_edge(pan_batch)
    pan_hp_batch = pan_hp_batch[:,:,:,np.newaxis] # expand to N*H*W*1
    
    ms_hp_batch = get_edge(ms_lr_batch)
    
    return gt_batch, lms_batch, pan_hp_batch, ms_hp_batch


def vis_ms(data): # for display only: the data has 8 bands, so pick bands 1, 2, 4 as blue, green, red (RGB has 3 channels)
    _,b,g,_,r,_,_,_ = tf.split(data,8,axis = 3)
    vis = tf.concat([r,g,b],axis = 3)
    return vis


########## PanNet structures ################ # the core of the method
def PanNet(ms, pan, num_spectral = 8, num_res = 4, num_fm = 32, reuse=False):
    
    weight_decay = 1e-5 # strength of the L2 weight regularizer
    #with tf.device('/gpu:0'):

    with tf.variable_scope('net'):        
        if reuse:
            tf.get_variable_scope().reuse_variables()
            
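        # ms is 32x16x16x8; the transposed convolution (kernel 8, stride 4) upsamples it 4x to 32x64x64x8
        ms = ly.conv2d_transpose(ms,num_spectral,8,4,activation_fn = None,   # 32 x 64 x64 x8
                                 weights_initializer = ly.variance_scaling_initializer(), 
                                 biases_initializer = None,  # no bias, so this graph matches the one in test.py
                                 weights_regularizer = ly.l2_regularizer(weight_decay))
        ms = tf.concat([ms,pan],axis=3)  # ms + pan: put together (concat): 32 x 64 x64 x9; axis is 0-based, so 3 is the 4th (channel) dimension

        # num_fm = 32: 32 kernels of size 3x3, i.e. 32 output feature maps
        # one convolution before entering the ResNet blocks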
        rs = ly.conv2d(ms, num_outputs = num_fm, kernel_size = 3, stride = 1,
                          weights_regularizer = ly.l2_regularizer(weight_decay),
                          weights_initializer = ly.variance_scaling_initializer(),
                          activation_fn = tf.nn.relu)   # 32x 64 x 64 x32
        
        for i in range(num_res):   # ResNet
            #  kernel: 3x3x32
            rs1 = ly.conv2d(rs, num_outputs = num_fm, kernel_size = 3, stride = 1, 
                            weights_regularizer = ly.l2_regularizer(weight_decay), 
                            weights_initializer = ly.variance_scaling_initializer(),
                            activation_fn = tf.nn.relu) # 32  x 64 x64 x32; first conv of the ResNet block + ReLU (non-linearity)

            #  kernel: 3x3x32
            rs1 = ly.conv2d(rs1, num_outputs = num_fm, kernel_size = 3, stride = 1, 
                            weights_regularizer = ly.l2_regularizer(weight_decay), 
                            weights_initializer = ly.variance_scaling_initializer(),
                            activation_fn = None) # 32 x 64 x64 x32; second conv of the ResNet block, no ReLU
            
            rs = tf.add(rs,rs1)   # ResNet:  32 x 64 x64 x32; add the block input x (compare with the ResNet figure)

        #  kernel: 3x3x8
        rs = ly.conv2d(rs, num_outputs = num_spectral, kernel_size = 3, stride = 1, 
                           weights_regularizer = ly.l2_regularizer(weight_decay), 
                           weights_initializer = ly.variance_scaling_initializer(),
                           activation_fn = None)  # 32 x 64 x64 x8
        return rs

 ###########################################################################
 ###########################################################################
 ########### Main Function: input data from here! (the functions above work like MATLAB sub-functions) ######

if __name__ =='__main__':

    tf.reset_default_graph()   

    train_batch_size = 32 # training batch size
    test_batch_size = 32  # validation batch size
    image_size = 64      # patch size: 64 for 64x64x8 patches, 100 for 100x100x8 patches
    iterations = 100100 # total number of iterations to use.
    model_directory = './models' # directory to save trained model to.
    train_data_name = './training_data/train.mat'  # training data
    test_data_name  = './training_data/validation.mat'   # validation data
    restore = False  # load model or not
    method = 'Adam'  # training method (the optimizer that minimizes the loss): Adam or SGD
    
############## loading data
    train_data = sio.loadmat(train_data_name)   # for small data (not v7.3 data)
    test_data = sio.loadmat(test_data_name)
    
    #train_data = h5py.File(train_data_name, 'r')  # for large data ( v7.3 data)
    #test_data  = h5py.File(test_data_name, 'r')

############## placeholder for training ########### # placeholders: the actual data is fed in later
    gt = tf.placeholder(dtype = tf.float32,shape = [train_batch_size,image_size,image_size,8]) # 32x64x64x8
    lms = tf.placeholder(dtype = tf.float32,shape = [train_batch_size,image_size,image_size,8])
    ms_hp = tf.placeholder(dtype = tf.float32,shape = [train_batch_size,image_size//4,image_size//4,8])#32x16x16x8
    pan_hp = tf.placeholder(dtype = tf.float32,shape = [train_batch_size,image_size,image_size,1])

############# placeholder for testing ##############
    test_gt = tf.placeholder(dtype = tf.float32,shape = [test_batch_size,image_size,image_size,8])
    test_lms = tf.placeholder(dtype = tf.float32,shape = [test_batch_size,image_size,image_size,8])
    test_ms_hp = tf.placeholder(dtype = tf.float32,shape = [test_batch_size,image_size//4,image_size//4,8])
    test_pan_hp = tf.placeholder(dtype = tf.float32,shape = [test_batch_size,image_size,image_size,1])

######## network architecture (call: PanNet constructed before!) ######################
    mrs = PanNet(ms_hp,pan_hp)    # call pannet
    mrs = tf.add(mrs,lms)        # 32 x64 x64 x8
    
    test_rs = PanNet(test_ms_hp,test_pan_hp,reuse = True)
    test_rs = test_rs + test_lms  # same as: test_rs = tf.add(test_rs,test_lms) 

######## loss function ################
    mse = tf.reduce_mean(tf.square(mrs - gt))  # compute cost: mean squared error (the l2_regularizer terms are not added to it)
    test_mse = tf.reduce_mean(tf.square(test_rs - test_gt))

##### Loss summary (for observation) ################ for visualization only; images are clipped to [0, 1] because the data was normalized to that range
    mse_loss_sum = tf.summary.scalar("mse_loss",mse)

    test_mse_sum = tf.summary.scalar("test_loss",test_mse)

    lms_sum = tf.summary.image("lms",tf.clip_by_value(vis_ms(lms),0,1))
    mrs_sum = tf.summary.image("rs",tf.clip_by_value(vis_ms(mrs),0,1))

    label_sum = tf.summary.image("label",tf.clip_by_value(vis_ms(gt),0,1))
    
    all_sum = tf.summary.merge([mse_loss_sum,mrs_sum,label_sum,lms_sum])

############ optimizer: Adam or SGD ################## with the loss defined, Adam or SGD optimizes the parameters θ
    t_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'net')    

    if method == 'Adam':
        g_optim = tf.train.AdamOptimizer(0.001, beta1 = 0.9) \
                          .minimize(mse, var_list=t_vars)

    else:
        global_steps = tf.Variable(0,trainable = False)
        lr = tf.train.exponential_decay(0.1,global_steps,decay_steps = 50000, decay_rate = 0.1)
        clip_value = 0.1/lr
        optim = tf.train.MomentumOptimizer(lr,0.9)
        gradient, var   = zip(*optim.compute_gradients(mse,var_list = t_vars))
        gradient, _ = tf.clip_by_global_norm(gradient,clip_value)
        g_optim = optim.apply_gradients(zip(gradient,var),global_step = global_steps)
        
##### GPU setting
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True # allocate GPU memory on demand

###########################################################################
###########################################################################
#### Run the above (take real data into the Net, for training) ############ the Session feeds in the data and runs the whole graph

    init = tf.global_variables_initializer()  # initialization: must be done!

    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:  # one session, using the GPU config above
        sess.run(init)
 
        if restore:
            print ('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_directory)
            saver.restore(sess,ckpt.model_checkpoint_path)

        #### read training data #####
        gt1 = train_data['gt'][...]  ## ground truth N*H*W*C
        pan1 = train_data['pan'][...]  #### Pan image N*H*W
        ms_lr1 = train_data['ms'][...]  ### low resolution MS image
        lms1 = train_data['lms'][...]  #### MS image interpolation to Pan scale

        gt1 = np.array(gt1, dtype=np.float32) / 2047.  ### [0, 1] normalization, WorldView L = 11
        pan1 = np.array(pan1, dtype=np.float32) / 2047.
        ms_lr1 = np.array(ms_lr1, dtype=np.float32) / 2047.
        lms1 = np.array(lms1, dtype=np.float32) / 2047.

        N = gt1.shape[0]

        #### read validation data #####
        gt2 = test_data['gt'][...]  ## ground truth N*H*W*C
        pan2 = test_data['pan'][...]  #### Pan image N*H*W
        ms_lr2 = test_data['ms'][...]  ### low resolution MS image
        lms2 = test_data['lms'][...]  #### MS image interpolation to Pan scale

        gt2 = np.array(gt2, dtype=np.float32) / 2047.  ### normalization, WorldView L = 11
        pan2 = np.array(pan2, dtype=np.float32) / 2047.
        ms_lr2 = np.array(ms_lr2, dtype=np.float32) / 2047.
        lms2 = np.array(lms2, dtype=np.float32) / 2047.
        N2 = gt2.shape[0]

        mse_train = [] # training MSE per step, used later to plot the error curve
        mse_valid = []
        
        for i in range(iterations): # training loop
            ###################################################################
            #### training phase! ###########################

            bs = train_batch_size
            batch_index = np.random.randint(0, N, size=bs)  # N = 100; draw bs = 32 random indices (with replacement)

            train_gt = gt1[batch_index, :, :, :]
            pan_batch = pan1[batch_index, :, :]
            ms_lr_batch = ms_lr1[batch_index, :, :, :]
            train_lms = lms1[batch_index, :, :, :]

            pan_hp_batch = get_edge(pan_batch)   # 32 x 64 x 64, high-pass filtered
            train_pan_hp = pan_hp_batch[:, :, :, np.newaxis]  # expand to N*H*W*1: 32 x64 x64 x1 (4-D)

            train_ms_hp = get_edge(ms_lr_batch) # 32 x16 x16 x8


            #train_gt, train_lms, train_pan_hp, train_ms_hp = get_batch(train_data, bs = train_batch_size)

            # run one step: feed_dict maps each placeholder (key) to the loaded data (value)
            _,mse_loss,merged = sess.run([g_optim,mse,all_sum],feed_dict = {gt: train_gt, lms: train_lms,
                                         ms_hp: train_ms_hp, pan_hp: train_pan_hp})

            mse_train.append(mse_loss)   # record the training MSE after every step

            if i % 100 == 0: # print the loss every 100 steps; it should keep decreasing overall

                print ("Iter: " + str(i) + " MSE: " + str(mse_loss))   # print, e.g.,: Iter: 0 MSE: 0.18406609

            if i % 5000 == 0 and i != 0: # every 5000 steps, save the model (the kernel weights) as a .ckpt checkpoint
                if not os.path.exists(model_directory):
                    os.makedirs(model_directory)
                saver.save(sess,model_directory+'/model-'+str(i)+'.ckpt')
                print ("Save Model")

            ###################################################################
            #### validation phase! ###########################

            bs_test = test_batch_size
            batch_index2 = np.random.randint(0, N2, size=bs_test)  # sample from the validation set (N2 samples), not the training set

            test_gt_batch = gt2[batch_index2, :, :, :]

            test_lms_batch = lms2[batch_index2, :, :, :]

            ms_lr_batch = ms_lr2[batch_index2, :, :, :]
            test_ms_hp_batch = get_edge(ms_lr_batch)

            pan_batch = pan2[batch_index2, :, :]
            pan_hp_batch = get_edge(pan_batch)
            test_pan_hp_batch = pan_hp_batch[:, :, :, np.newaxis]  # expand to N*H*W*1


            '''if i%1000 == 0 and i!=0:  # after 1000 iteration, re-set: get_batch
                test_gt_batch, test_lms_batch, test_pan_hp_batch, test_ms_hp_batch = get_batch(test_data, bs = test_batch_size)'''
                
            test_mse_loss,merged = sess.run([test_mse,test_mse_sum],
                                               feed_dict = {test_gt : test_gt_batch, test_lms : test_lms_batch,
                                                            test_ms_hp : test_ms_hp_batch, test_pan_hp : test_pan_hp_batch})

            mse_valid.append(test_mse_loss)  # record the validation MSE

            if i % 1000 == 0 and i != 0: # print the validation error every 1000 steps
                print("Iter: " + str(i) + " Valid MSE: " + str(test_mse_loss))  # print, e.g.,: Iter: 0 MSE: 0.18406609
                
        ## finally write the mse info ##
        file = open('train_mse.txt','w')  # write the training error into train_mse.txt
        file.write(str(mse_train))
        file.close()

        file = open('valid_mse.txt','w')  # write the valid error into valid_mse.txt
        file.write(str(mse_valid))
        file.close()

test.py

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
# This is a re-implementation of the testing code of this paper:
# J. Yang, X. Fu, Y. Hu, Y. Huang, X. Ding, J. Paisley. "PanNet: A deep network architecture for pan-sharpening", ICCV,2017. 
# author: Junfeng Yang

"""
import tensorflow as tf
import tensorflow.contrib.layers as ly
import numpy as np
import scipy.io as sio
import cv2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3' # hide TensorFlow's C++ INFO, WARNING and ERROR messages

def PanNet(ms, pan, num_spectral = 8, num_res = 4, num_fm = 32, reuse=False):
    
    weight_decay = 1e-5
    with tf.variable_scope('net'):        
        if reuse:
            tf.get_variable_scope().reuse_variables()
            
        
        ms = ly.conv2d_transpose(ms,num_spectral,8,4,activation_fn = None, weights_initializer = ly.variance_scaling_initializer(), 
                                 biases_initializer = None,
                                 weights_regularizer = ly.l2_regularizer(weight_decay))
        ms = tf.concat([ms,pan],axis=3)

        rs = ly.conv2d(ms, num_outputs = num_fm, kernel_size = 3, stride = 1, 
                          weights_regularizer = ly.l2_regularizer(weight_decay), 
                          weights_initializer = ly.variance_scaling_initializer(),activation_fn = tf.nn.relu)
        
        for i in range(num_res):
            rs1 = ly.conv2d(rs, num_outputs = num_fm, kernel_size = 3, stride = 1, 
                          weights_regularizer = ly.l2_regularizer(weight_decay), 
                          weights_initializer = ly.variance_scaling_initializer(),activation_fn = tf.nn.relu)
            rs1 = ly.conv2d(rs1, num_outputs = num_fm, kernel_size = 3, stride = 1, 
                          weights_regularizer = ly.l2_regularizer(weight_decay), 
                          weights_initializer = ly.variance_scaling_initializer(),activation_fn = None)
            rs = tf.add(rs,rs1)
        
        rs = ly.conv2d(rs, num_outputs = num_spectral, kernel_size = 3, stride = 1, 
                          weights_regularizer = ly.l2_regularizer(weight_decay), 
                          weights_initializer = ly.variance_scaling_initializer(),activation_fn = None)
        return rs

def get_edge(data): # get high-frequency
    rs = np.zeros_like(data)
    if len(rs.shape) ==3:
        for i in range(data.shape[2]):
            rs[:,:,i] = data[:,:,i] -cv2.boxFilter(data[:,:,i],-1,(5,5))
    else:
        rs = data - cv2.boxFilter(data,-1,(5,5))
    return rs


#################################################################
################# Main function ################################## PanNet above is the same network as in train.py; get_edge here works on a single image
if __name__=='__main__':

    test_data = 'new_data.mat'

    model_directory = './models/'

    tf.reset_default_graph() # clear the default graph
    
    data = sio.loadmat(test_data)
    
    ms = data['ms'][...]      # MS image 64x64x8
    ms = np.array(ms,dtype = np.float32) /2047.


    lms = data['lms'][...]    # up-sampled LRMS image 256x256x8
    lms = np.array(lms, dtype = np.float32) /2047.

    pan  = data['pan'][...]  # PAN image 256x256
    pan  = np.array(pan,dtype = np.float32) /2047.
    
     
    ms_hp = get_edge(ms)   # high-frequency parts of MS image
    ms_hp = ms_hp[np.newaxis,:,:,:]  # 1x64x64x8: add a batch dimension

    pan_hp = get_edge(pan) # high-frequency parts of PAN image: 256x256
    pan_hp = pan_hp[np.newaxis,:,:,np.newaxis]  # 1x256x256x1: add batch and channel dimensions

    h = pan.shape[0] # height
    w = pan.shape[1] # width

    lms   = lms[np.newaxis,:,:,:]  # 1x256x256x8: add a batch dimension
    
##### placeholder for testing#######
    p_hp = tf.placeholder(shape=[1,h,w,1],dtype=tf.float32)
    m_hp = tf.placeholder(shape=[1,h//4,w//4,8],dtype=tf.float32)  # integer division: MS is 4x smaller than PAN
    lms_p = tf.placeholder(shape=[1,h,w,8],dtype=tf.float32)


    rs = PanNet(m_hp,p_hp) # output high-frequency parts: feed ms_hp and pan_hp into the network
    
    mrs = tf.add(rs,lms_p) # adding lms gives the high-resolution result
    
    output = tf.clip_by_value(mrs,0,1) # final output: values above 1 become 1, values below 0 become 0



################################################################
##################Session Run ##################################
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
    with tf.Session() as sess:  
        sess.run(init)
        
        # loading  model       
        if tf.train.get_checkpoint_state(model_directory):  
           ckpt = tf.train.latest_checkpoint(model_directory)
           saver.restore(sess, ckpt)
           print ("load new model")

        else:
           ckpt = tf.train.get_checkpoint_state(model_directory + "pre-trained/")
           saver.restore(sess,ckpt.model_checkpoint_path) # this model uses 128 feature maps and for debug only                                                                   
           print ("load pre-trained model")                            
        


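        final_output = sess.run(output,feed_dict = {p_hp:pan_hp, m_hp:ms_hp, lms_p:lms})# 1x256x256x8

        if not os.path.exists('./result'):  # savemat fails if the folder does not exist
            os.makedirs('./result')
        sio.savemat('./result/output.mat', {'output':final_output[0,:,:,:]}) # 256x256x8, saved as a .mat file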
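Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and the site assumes no legal liability. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/不正经/article/detail/592594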