赞
踩
TensorFlow网络在输入Numpy数据时会自动转换为Tensor来处理,但是我们自己也可以去显式的转换:
data_tensor= tf.convert_to_tensor(data_numpy)
网络输出的结果仍为Tensor,当我们要用这些结果去执行只能由Numpy数据来执行的操作时就会出现莫名其妙的错误。解决方法
with tf.Session() as sess:
data_numpy = data_tensor.eval()
数据集的处理,卷积。conv2d
import numpy as np import tensorflow as tf from sklearn.datasets import load_sample_images import matplotlib.pyplot as plt #输入的照片是三通道的(三维),读取数据 daraset=np.array(load_sample_images().images,dtype=np.float32) #数据集中有一个寺庙的照片和一个画的照片 b_size,height,weight,channel=daraset.shape#参数分别为图片数目,图片的宽,图片的高,图片的通道 print(b_size,height,weight,channel) #创建两个卷积核 filter1=np.zeros(shape=(7,7,channel,2),dtype=np.float32) filter1[:,3,:,0]=1 filter1[3,:,:,1]=2 # filter参数是一个filters的集合 X = tf.placeholder(tf.float32, shape=(None, height, weight, channel)) #经过卷积后的层 conv=tf.nn.conv2d(X,filter=filter1,strides=[1,2,2,1],padding="SAME") with tf.Session() as sess: output=sess.run(conv,feed_dict={X:daraset}) plt.imshow(output[0,:,:,0])#第一个图的第一个特征图 plt.show() #第一个图的第二个特征图 plt.imshow(output[0,:,:,1]) plt.show() plt.imshow(output[1,:,:,0])#第一个图的第一个特征图 plt.show() #第一个图的第二个特征图 plt.imshow(output[1,:,:,1]) plt.show()
对图片进行池化操作,max_pool最大池化。
import numpy as np from sklearn.datasets import load_sample_images import tensorflow as tf import matplotlib.pyplot as plt # 加载数据集 # 输入图片通常是3D,[height, width, channels] # mini-batch通常是4D,[mini-batch size, height, width, channels] dataset = np.array(load_sample_images().images, dtype=np.float32) # 数据集里面两张图片,一个中国庙宇,一个花 batch_size, height, width, channels = dataset.shape print(batch_size, height, width, channels) # 创建输入和一个池化层 X = tf.placeholder(tf.float32, shape=(None, height, width, channels)) # TensorFlow不支持池化多个实例,所以ksize的第一个batch size是1 # TensorFlow不支持池化同时发生的长宽高,所以必须有一个是1,这里channels就是depth维度为1 max_pool = tf.nn.max_pool(X, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='VALID') # avg_pool() with tf.Session() as sess: output = sess.run(max_pool, feed_dict={X: dataset}) plt.imshow(dataset[0].astype(np.uint8)) plt.show() plt.imshow(output[0].astype(np.uint8)) # 画输入的第一个图像 plt.show() plt.imshow(dataset[1].astype(np.uint8)) plt.show() plt.imshow(output[1].astype(np.uint8)) # 画输入的第一个图像 plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。