SRCNN—train
--srcnn = SRCNN() #初始化
------参数初始化
------build_model() #定义模型的输入输出,超参数,网络结构,loss,并实例化一个 Saver 对象
--train(self, config)
------input_setup(self.sess, config) #将输入和输出图片切成小块保存成 HDF5 文件
|--------prepare_data(sess, dataset="Train") #返回图片地址列表,sess 无用
------------data_dir = os.path.join(os.getcwd(), dataset) #os.getcwd():得到当前文件路径
------------data = glob.glob(os.path.join(data_dir, "*.bmp")) #图片路径列表
|--------for i in xrange(len(data)): #对于每张图片(每张图片大小不一致)
------------preprocess(data[i], config.scale) #预处理图片,返回(低,高)分辨率图像对
----------image = imread(path, is_grayscale=True) #以 YCbCr 格式读取原始图像(默认为灰度)
----------label_ = modcrop(image, scale) #将图片规整到可以被 scale 整除的宽和高
----------image = image / 255. #无用,可删掉
----------label_ = label_ / 255. #归一化
----------scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False) #进行三次插值缩小 scale 倍
----------scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False) #进行三次插值放大回原有大小
------------滑动窗口裁剪:将低分辨率图像和高分辨率图像分别裁剪成 33*33 和 21*21 的小图,步长 stride=6。
------------重定义类型(reshape)后添加到输入和输出列表
|--------np.asarray() #将输入列表和输出列表中的数据转为 numpy 类型
|--------make_data(sess, arrdata, arrlabel) #制作 h5 文件,这个是产生训练数据的函数
------------savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5') #定义 h5 文件保存地址
------------hf.create_dataset(name,data) #建立一个名为 name(字符串),数据为 data(numpy 数组)的数据集
------train_data, train_label = read_data(data_dir) #获取 HDF5 格式数据集
------tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss) #建立 SGD 优化器
------tf.initialize_all_variables().run() #初始化所有参数
------self.load(self.checkpoint_dir) #加载训练过的参数
|--------定义路径为:checkpoint_dir/srcnn_21
|--------ckpt=tf.train.get_checkpoint_state(checkpoint_dir) #从文件夹中获取 checkpoint 文件
|--------ckpt_name = os.path.basename(ckpt.model_checkpoint_path) #获取 checkpoint 文件的文件名
|--------saver.restore(…) #重载模型的参数,继续训练或者用于测试数据
------for ep in xrange(config.epoch): #对于每次 epoch
|--------batch_idxs #计算一个 epoch 有多少波 batch,放在循坏外不香吗
|--------for idx in xrange(0, batch_idxs): #对每波 batch
------------取 batch_size 对训练数据
------------self.sess.run([self.train_op, self.loss]) #运行数据流至 loss 操作,并使用优化器进行优化
------------print(…)#每更新 10 次输出一次数据
------------self.save(config.checkpoint_dir, counter) #每更新 500 次保存一次数据
----------self.saver.save() #向文件夹中写入包含当前模型中所有可训练变量的 checkpoint 文件,之后可以使用 saver.restore()方法,重载模型的
参数,继续训练或者用于测试数据