SRCNN—test
--srcnn = SRCNN() #初始化
------参数初始化
------build_model() #定义模型的输入输出,初始化超参数,网络结构,loss,并实例化一个 Saver 对象
|--------self.pred = self.model() #只要调用了 model 函数就会有输出,返回三层卷积后的结果,网络输出
--train(self, config)
------nx, ny = input_setup(self.sess, config) #将输入和输出图片切成小块保存成 HDF5 文件,返回每行每列能裁
剪出的子图个数
|--------prepare_data(sess, dataset="Test") #返回测试集图片地址列表,sess 无用
------------data_dir = os.path.join(…) #os.getcwd():得到当前文件路径
------------data = glob.glob(os.path.join(data_dir, "*.bmp")) #图片路径列表
|--------input_, label_ = preprocess(data[0], config.scale) #预处理图片,只对其中的一张图片做处理。
,返回(低,高)分辨率图像对
------------image = imread(path, is_grayscale=True) #以 YCbCr 格式读取原始图像(默认为灰度)
------------label_ = modcrop(image, scale) #将图片规整到可以被 scale 整除的宽和高
------------image = image / 255. #无用,可删掉
------------label_ = label_ / 255. #归一化
------------scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False) #进行三次插值缩小 scale 倍
------------scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False) #进行三次插值放大回原有大小
|--------nx = ny = 0 #宽度和高度方向上分别裁剪出的子图数量
|--------滑动窗口裁剪:将低分辨率图像和高分辨率图像分别裁剪成 33*33 和 21*21 的小图,步长 stride=6。
|--------重定义类型(reshape)后添加到输入和输出列表
|--------np.asarray() #将输入列表和输出列表中的数据转为 numpy 类型
|--------make_data(sess, arrdata, arrlabel) #制作 h5 文件,这个是产生训练数据的函数
------------savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5') #定义 h5 文件保存地址
------------hf.create_dataset(name,data) #建立一个名为 name(字符串),数据为 data(numpy 数组)的数据集
------train_data, train_label = read_data(data_dir) #获取 HDF5 格式测试集
------tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss) #建立 SGD 优化器
------tf.initialize_all_variables().run() #初始化所有参数
------self.load(self.checkpoint_dir) #加载训练过的参数
|--------定义路径为:checkpoint_dir/srcnn_21
|--------ckpt=tf.train.get_checkpoint_state(checkpoint_dir) #从文件夹中获取 checkpoint 文件
|--------ckpt_name = os.path.basename(ckpt.model_checkpoint_path) #获取 checkpoint 文件的文件名
|--------saver.restore(…) #重载模型的参数,继续训练或者用于测试数据
------result = self.pred.eval({self.images: train_data, self.labels: train_label}) #网络输出
------result = merge(result, [nx, ny]) #将一个 batch 内的图片拼接到一起
------result = result.squeeze() #squeeze 去除维度为 1 的地方
------image_path = os.path.join(os.getcwd(), config.sample_dir) #生成图片的保存地址
------image_path = os.path.join(image_path, "test_image.png") #保存地址/图片名
------imsave(result, image_path) #保存图片