logo资料库

SRCNN代码解析_test.pdf

第1页 / 共3页
第2页 / 共3页
第3页 / 共3页
资料共3页,全文预览结束
SRCNN—test --srcnn = SRCNN() #初始化 ------参数初始化 ------build_model() #定义模型的输入输出,初始化超参数,网络结构,loss,并实例化一个 Saver 对象 |--------self.pred = self.model() #只要调用了 model 函数就会有输出,返回三层卷积后的结果,网络输出 --train(self, config) ------nx, ny = input_setup(self.sess, config) #将输入和输出图片切成小块保存成 HDF5 文件,返回每行每列能裁 剪出的子图个数 |--------prepare_data(sess, dataset="Test") #返回测试集图片地址列表,sess 无用 ------------data_dir = os.path.join(…) #os.getcwd():得到当前文件路径 ------------data = glob.glob(os.path.join(data_dir, "*.bmp")) #图片路径列表 |--------input_, label_ = preprocess(data[0], config.scale) #预处理图片,只对其中的一张图片做处理。 ,返回(低,高)分辨率图像对 ------------image = imread(path, is_grayscale=True) #以 YCbCr 格式读取原始图像(默认为灰度) ------------label_ = modcrop(image, scale) #将图片规整到可以被 scale 整除的宽和高 ------------image = image / 255. #无用,可删掉 ------------label_ = label_ / 255. #归一化
------------scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False) #进行三次插值缩小 scale 倍 ------------scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False) #进行三次插值放大回原有大小 |--------nx = ny = 0 #宽度和高度方向上分别裁剪出的子图数量 |--------滑动窗口裁剪:将低分辨率图像和高分辨率图像分别裁剪成 33*33 和 21*21 的小图,步长 stride=6。 |--------重定义类型(reshape)后添加到输入和输出列表 |--------np.asarray() #将输入列表和输出列表中的数据转为 numpy 类型 |--------make_data(sess, arrdata, arrlabel) #制作 h5 文件,这个是产生训练数据的函数 ------------savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5') #定义 h5 文件保存地址 ------------hf.create_dataset(name,data) #建立一个名为 name(字符串),数据为 data(numpy 数组)的数据集 ------train_data, train_label = read_data(data_dir) #获取 HDF5 格式测试集 ------tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss) #建立 SGD 优化器 ------tf.initialize_all_variables().run() #初始化所有参数 ------self.load(self.checkpoint_dir) #加载训练过的参数 |--------定义路径为:checkpoint_dir/srcnn_21 |--------ckpt=tf.train.get_checkpoint_state(checkpoint_dir) #从文件夹中获取 checkpoint 文件 |--------ckpt_name = os.path.basename(ckpt.model_checkpoint_path) #获取 checkpoint 文件的文件名 |--------saver.restore(…) #重载模型的参数,继续训练或者用于测试数据 ------result = self.pred.eval({self.images: train_data, self.labels: train_label}) #网络输出 ------result = merge(result, [nx, ny]) #将一个 batch 内的图片拼接到一起
------result = result.squeeze() #squeeze 去除维度为 1 的地方 ------image_path = os.path.join(os.getcwd(), config.sample_dir) #生成图片的保存地址 ------image_path = os.path.join(image_path, "test_image.png") #保存地址/图片名 ------imsave(result, image_path) #保存图片
分享到:
收藏