TensorFlow (4) - 保存/恢复/提前终止

本章主要说明如何对训练过程进行存储,恢复操作。以及在模型长期没有改进时,提前终止训练。

TensorFlow (4) - 保存/恢复/提前终止

这是几篇与原作不完全相同的教程。

转载请说明出处:TensorFlow (4) - 保存/恢复/提前终止

原作者:Magnus Erik Hvass Pedersen / GitHub / Videos on YouTube

第三章 中,我们使用 TensorFlow 的 API 重新构建了网络,并且训练了 10000 轮次得到了 98.7% 的测试准确率。然而,在实际训练过程中,模型的结构往往比这更加复杂,数据量也更多,训练 10000 轮次需要较长的时间。长时间的训练存在多个问题:

  • 如果不保存训练过程中的状态,在程序运行完后,整个训练过程被销毁,无法运用到新的数据集上,这样的模型没有意义。
  • 如果在训练过程中一旦出现突发情况停止,之前的训练将前功尽弃,必须从头开始训练。
  • 如果经过多次的训练,模型的性能不再提升,那么这些训练是没有意义的,如果不提前终止,将浪费大量时间。

事实上,TensorFlow 提供了保存和恢复训练的方法,可以避免这类事情的发生。在本章中,我们将继续使用 第三章 的大部分代码,只做小部分的修改,以实现我们的目的。

载入数据,构建模型

这一部分与第三章相似,但是我们在本章中使用验证集来验证模型性能,因此需要稍作改动。此外,需要用到上一章给出的 cnn_helper.py

from cnn_helper import *

# notebook使用
%load_ext autoreload  
%autoreload 2
%matplotlib inline

数据

from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('data/MNIST/', one_hot=True)

print()
print("数据集大小:")
print('- 训练集:{}'.format(len(data.train.labels)))
print('- 测试集:{}'.format(len(data.test.labels)))
print('- 验证集:{}'.format(len(data.validation.labels)))
数据集大小:
- 训练集:55000
- 测试集:10000
- 验证集:5000
data.test.cls = np.argmax(data.test.labels, axis=1)
data.validation.cls = np.argmax(data.validation.labels, axis=1)

神经网络参数

img_size = 28                        # 图片的高度和宽度
img_size_flat = img_size * img_size  # 展平为向量的尺寸
img_shape = (img_size, img_size)     # 图片的二维尺寸

num_channels = 1                     # 输入为单通道灰度图像
num_classes = 10                     # 类别数目

# 卷积层 1
filter_size1 = 5          # 5 x 5 卷积核
num_filters1 = 16         # 共 16 个卷积核

# 卷积层 2
filter_size2 = 5          # 5 x 5 卷积核
num_filters2 = 36         # 共 36 个卷积核

# 全连接层
fc_size = 128             # 全连接层神经元数

占位符

x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')          # 原始输入
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])                # 转换为2维图像
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')  # 原始输出
y_true_cls = tf.argmax(y_true, axis=1)                                         # 转换为真实类别

卷积神经网络

layer_conv1 = tf.layers.conv2d(inputs=x_image,            # 输入
                               filters=num_filters1,      # 卷积核个数
                               kernel_size=filter_size1,  # 卷积核尺寸
                               padding='same',            # padding方法
                               activation=tf.nn.relu,     # 激活函数relu
                               name='layer_conv1')        # 命名用于获取变量

net = tf.layers.max_pooling2d(inputs=layer_conv1, pool_size=2, strides=(2, 2),  padding='same')
layer_conv2 = tf.layers.conv2d(inputs=net,          
                               filters=num_filters2,
                               kernel_size=filter_size2,
                               padding='same',
                               activation=tf.nn.relu,
                               name='layer_conv2')

net = tf.layers.max_pooling2d(inputs=layer_conv2, pool_size=2, strides=(2, 2),  padding='same')
layer_flat = tf.contrib.layers.flatten(net)    # flatten暂时在tf.contrib一层

layer_fc1 = tf.layers.dense(inputs=layer_flat, units=fc_size, activation=tf.nn.relu, name='layer_fc1')
layer_fc2 = tf.layers.dense(inputs=layer_fc1, units=num_classes, name='layer_fc2')

代价、优化器、准确率

y_pred = tf.nn.softmax(layer_fc2)              # softmax归一化
y_pred_cls = tf.argmax(y_pred, axis=1)         # 真实类别

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=layer_fc2, labels=y_true)
cost = tf.reduce_mean(cross_entropy)

optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)

correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

获取权重

def get_weights_variable(layer_name):
    # 根据给定的layer_name,返回名为'kernel'的变量
    with tf.variable_scope(layer_name, reuse=True):
        variable = tf.get_variable('kernel')
    return variable

weights_conv1 = get_weights_variable(layer_name='layer_conv1')
weights_conv2 = get_weights_variable(layer_name='layer_conv2')

Saver

为了保存神经网络中的变量,我们需要创建一个 Saver 对象用来存储和检索 TensorFlow 计算图中的所有变量。我们可以保存训练过程中的所有结果,在这里仅保存最优的结果。

saver = tf.train.Saver()    # 用于保存变量

save_dir = 'checkpoints/'   # 保存目录
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

save_path = os.path.join(save_dir, 'best_validation')   # 最佳验证结果保存路径

运行TensorFlow

创建session并初始化

session = tf.Session()
session.run(tf.global_variables_initializer())

优化器的迭代过程

为了测试优化器的性能,需要多添加几个指标,对代码进行相应调整。

train_batch_size = 64

best_validation_accuracy = 0.0    # 当前最佳验证集准确率
last_improvement = 0     # 上一次有所改进的轮次
require_improvement = 1000    # 如果在1000轮内没有改进,停止迭代
# 计算目前执行的总迭代次数
total_iterations = 0

def optimize(num_iterations):
    # 保证更新全局变量.
    global total_iterations
    global best_validation_accuracy
    global last_improvement

    # 用来输出用时.
    start_time = time.time()

    for i in range(num_iterations):
        total_iterations += 1

        # 获取一批数据,运行优化器
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)
        feed_dict_train = {x: x_batch, y_true: y_true_batch}
        session.run(optimizer, feed_dict=feed_dict_train)

        # 每100轮迭代输出状态
        if (total_iterations % 100 == 0) or (i == num_iterations - 1):
            # 计算训练集准确率.
            acc_train = session.run(accuracy, feed_dict=feed_dict_train)
            # 验证集准确率,为了尽可能重用代码,这个函数会在后面实现
            acc_validation, _ = validation_accuracy()

            if acc_validation > best_validation_accuracy:   # 如果当前验证集准确率大于之前的最好准确率
                best_validation_accuracy = acc_validation   # 更新最好准确率
                last_improvement = total_iterations         # 更新上一次提升的迭代轮次

                saver.save(sess=session, save_path=save_path)   # 将这一次更新保存下来
                improved_str = '*'    # 标注为找到提升
            else:
                improved_str = ''    

            msg = "迭代轮次: {0:>6}, 训练集准确率: {1:>6.1%}, 验证集准确率: {2:>6.1%} {3}"
            print(msg.format(i + 1, acc_train, acc_validation, improved_str))

        # 如果在require_improvement轮次内未有提升
        if total_iterations - last_improvement > require_improvement:
            print("长时间未提升, 停止优化。")
            break  # 跳出循环

    end_time = time.time()
    time_dif = end_time - start_time

    # 输出用时.
    print("用时: " + str(timedelta(seconds=int(round(time_dif)))))

计算分类性能

为了重用代码以评估验证集和测试集的性能,需要重构这部分代码:

batch_size = 256

def predict_cls(images, labels, cls_true):
    num_images = len(images)

    # 为预测结果申请一个数组
    cls_pred = np.zeros(shape=num_images, dtype=np.int)

    i = 0  # 数据集的起始id为0

    while i < num_images:
        # j为下一批次的截止id
        j = min(i + batch_size, num_images)

        # 创建feed_dict
        feed_dict = {x: images[i:j, :], y_true: labels[i:j, :]}

        # 计算预测结果
        cls_pred[i:j] = session.run(y_pred_cls, feed_dict=feed_dict)

        # 设定为下一批次起始值.
        i = j

    # 正确的分类
    correct = (cls_true == cls_pred)

    return correct, cls_pred
def predict_cls_test():    # 测试集分类结果
    return predict_cls(images = data.test.images,
                       labels = data.test.labels,
                       cls_true = data.test.cls)

def predict_cls_validation():   # 验证集分类结果
    return predict_cls(images = data.validation.images,
                       labels = data.validation.labels,
                       cls_true = data.validation.cls)

def cls_accuracy(correct):    # 准确率
    # 计算总的正确个数
    correct_sum = correct.sum()
    #计算准确率
    acc = float(correct_sum) / len(correct)
    return acc, correct_sum

def validation_accuracy():    # optimize()用到的验证集准确率函数
    correct, _ = predict_cls_validation()
    return cls_accuracy(correct)

测试集性能评估

def print_test_accuracy(show_example_errors=False,
                        show_confusion_matrix=False):

    # 返回测试集的正确情况与预测结果
    correct, cls_pred = predict_cls_test()

    # 计算准确率,准确数
    acc, num_correct = cls_accuracy(correct)

    num_images = len(correct)

    # 打印准确率.
    msg = "测试集准确率: {0:.1%} ({1} / {2})"
    print(msg.format(acc, num_correct, num_images))

    # 打印部分错误样例.
    if show_example_errors:
        print("Example errors:")
        plot_example_errors(data_test=data.test, cls_pred=cls_pred, correct=correct, img_shape=img_shape)

    # 打印混淆矩阵.
    if show_confusion_matrix:
        print("Confusion Matrix:")
        plot_confusion_matrix(cls_true=data.test.cls, cls_pred=cls_pred)

优化前性能:

print_test_accuracy()
测试集准确率: 13.5% (1348 / 10000)

执行 10000 轮迭代:

optimize(num_iterations=10000)
迭代轮次:    100, 训练集准确率:  81.2%, 验证集准确率:  78.9% *
迭代轮次:    200, 训练集准确率:  85.9%, 验证集准确率:  86.9% *
迭代轮次:    300, 训练集准确率:  93.8%, 验证集准确率:  90.9% *
迭代轮次:    400, 训练集准确率:  92.2%, 验证集准确率:  92.1% *
迭代轮次:    500, 训练集准确率:  93.8%, 验证集准确率:  93.0% *
迭代轮次:    600, 训练集准确率:  95.3%, 验证集准确率:  93.7% *
迭代轮次:    700, 训练集准确率:  90.6%, 验证集准确率:  94.3% *
迭代轮次:    800, 训练集准确率:  98.4%, 验证集准确率:  94.7% *
迭代轮次:    900, 训练集准确率:  96.9%, 验证集准确率:  95.4% *
迭代轮次:   1000, 训练集准确率:  96.9%, 验证集准确率:  95.5% *
迭代轮次:   1100, 训练集准确率:  93.8%, 验证集准确率:  95.7% *
迭代轮次:   1200, 训练集准确率:  98.4%, 验证集准确率:  96.0% *
迭代轮次:   1300, 训练集准确率:  96.9%, 验证集准确率:  95.8%
迭代轮次:   1400, 训练集准确率:  96.9%, 验证集准确率:  96.1% *
迭代轮次:   1500, 训练集准确率:  98.4%, 验证集准确率:  96.3% *
迭代轮次:   1600, 训练集准确率:  98.4%, 验证集准确率:  96.7% *
迭代轮次:   1700, 训练集准确率:  92.2%, 验证集准确率:  96.8% *
迭代轮次:   1800, 训练集准确率:  96.9%, 验证集准确率:  96.9% *
迭代轮次:   1900, 训练集准确率: 100.0%, 验证集准确率:  96.8%
迭代轮次:   2000, 训练集准确率:  98.4%, 验证集准确率:  97.0% *
迭代轮次:   2100, 训练集准确率: 100.0%, 验证集准确率:  97.0% *
迭代轮次:   2200, 训练集准确率:  98.4%, 验证集准确率:  97.4% *
迭代轮次:   2300, 训练集准确率:  96.9%, 验证集准确率:  97.3%
迭代轮次:   2400, 训练集准确率:  96.9%, 验证集准确率:  97.3%
迭代轮次:   2500, 训练集准确率:  96.9%, 验证集准确率:  97.5% *
迭代轮次:   2600, 训练集准确率:  98.4%, 验证集准确率:  97.3%
迭代轮次:   2700, 训练集准确率:  95.3%, 验证集准确率:  97.4%
迭代轮次:   2800, 训练集准确率: 100.0%, 验证集准确率:  97.7% *
迭代轮次:   2900, 训练集准确率:  98.4%, 验证集准确率:  97.7%
迭代轮次:   3000, 训练集准确率:  93.8%, 验证集准确率:  97.8% *
迭代轮次:   3100, 训练集准确率:  98.4%, 验证集准确率:  97.9% *
迭代轮次:   3200, 训练集准确率:  98.4%, 验证集准确率:  97.9%
迭代轮次:   3300, 训练集准确率:  96.9%, 验证集准确率:  97.9%
迭代轮次:   3400, 训练集准确率:  98.4%, 验证集准确率:  97.9%
迭代轮次:   3500, 训练集准确率:  96.9%, 验证集准确率:  97.8%
迭代轮次:   3600, 训练集准确率: 100.0%, 验证集准确率:  98.0% *
迭代轮次:   3700, 训练集准确率:  98.4%, 验证集准确率:  97.8%
迭代轮次:   3800, 训练集准确率: 100.0%, 验证集准确率:  97.9%
迭代轮次:   3900, 训练集准确率:  98.4%, 验证集准确率:  97.9%
迭代轮次:   4000, 训练集准确率: 100.0%, 验证集准确率:  97.7%
迭代轮次:   4100, 训练集准确率:  98.4%, 验证集准确率:  98.1% *
迭代轮次:   4200, 训练集准确率:  98.4%, 验证集准确率:  98.0%
迭代轮次:   4300, 训练集准确率: 100.0%, 验证集准确率:  97.9%
迭代轮次:   4400, 训练集准确率:  98.4%, 验证集准确率:  98.1% *
迭代轮次:   4500, 训练集准确率: 100.0%, 验证集准确率:  98.1%
迭代轮次:   4600, 训练集准确率: 100.0%, 验证集准确率:  97.9%
迭代轮次:   4700, 训练集准确率:  98.4%, 验证集准确率:  98.2% *
迭代轮次:   4800, 训练集准确率:  96.9%, 验证集准确率:  98.2% *
迭代轮次:   4900, 训练集准确率: 100.0%, 验证集准确率:  98.2%
迭代轮次:   5000, 训练集准确率: 100.0%, 验证集准确率:  98.3% *
迭代轮次:   5100, 训练集准确率: 100.0%, 验证集准确率:  98.3%
迭代轮次:   5200, 训练集准确率:  98.4%, 验证集准确率:  98.3%
迭代轮次:   5300, 训练集准确率: 100.0%, 验证集准确率:  98.4% *
迭代轮次:   5400, 训练集准确率:  98.4%, 验证集准确率:  98.3%
迭代轮次:   5500, 训练集准确率:  96.9%, 验证集准确率:  98.4% *
迭代轮次:   5600, 训练集准确率:  98.4%, 验证集准确率:  98.3%
迭代轮次:   5700, 训练集准确率:  98.4%, 验证集准确率:  98.4% *
迭代轮次:   5800, 训练集准确率:  98.4%, 验证集准确率:  98.4% *
迭代轮次:   5900, 训练集准确率:  96.9%, 验证集准确率:  98.4%
迭代轮次:   6000, 训练集准确率:  95.3%, 验证集准确率:  98.4%
迭代轮次:   6100, 训练集准确率: 100.0%, 验证集准确率:  98.3%
迭代轮次:   6200, 训练集准确率: 100.0%, 验证集准确率:  98.5% *
迭代轮次:   6300, 训练集准确率: 100.0%, 验证集准确率:  98.4%
迭代轮次:   6400, 训练集准确率: 100.0%, 验证集准确率:  98.6% *
迭代轮次:   6500, 训练集准确率: 100.0%, 验证集准确率:  98.4%
迭代轮次:   6600, 训练集准确率: 100.0%, 验证集准确率:  98.7% *
迭代轮次:   6700, 训练集准确率: 100.0%, 验证集准确率:  98.7% *
迭代轮次:   6800, 训练集准确率: 100.0%, 验证集准确率:  98.4%
迭代轮次:   6900, 训练集准确率: 100.0%, 验证集准确率:  98.6%
迭代轮次:   7000, 训练集准确率:  96.9%, 验证集准确率:  98.7%
迭代轮次:   7100, 训练集准确率:  98.4%, 验证集准确率:  98.4%
迭代轮次:   7200, 训练集准确率: 100.0%, 验证集准确率:  98.5%
迭代轮次:   7300, 训练集准确率:  98.4%, 验证集准确率:  98.6%
迭代轮次:   7400, 训练集准确率:  98.4%, 验证集准确率:  98.6%
迭代轮次:   7500, 训练集准确率:  98.4%, 验证集准确率:  98.4%
迭代轮次:   7600, 训练集准确率: 100.0%, 验证集准确率:  98.6%
迭代轮次:   7700, 训练集准确率: 100.0%, 验证集准确率:  98.7%
长时间未提升, 停止优化。
用时: 0:13:31

我们发现,在经过 6700 轮次后,由于性能不再提升,优化器提前停止了迭代,减少了部分的训练时间,而验证集的准确率达到了 98.7%。

print_test_accuracy(show_example_errors=True, show_confusion_matrix=True)
测试集准确率: 98.7% (9865 / 10000)
Example errors:
Confusion Matrix:
[[ 973    0    1    0    0    1    2    1    2    0]
 [   0 1127    2    0    0    0    2    1    3    0]
 [   0    1 1021    1    1    0    0    3    4    1]
 [   0    0    1  998    0    7    0    1    3    0]
 [   0    0    1    0  979    0    1    1    0    0]
 [   2    0    0    3    0  885    2    0    0    0]
 [   5    3    0    0    3    5  940    0    2    0]
 [   1    1    6    2    0    0    0 1017    1    0]
 [   3    0    4    2    1    2    0    3  957    2]
 [   2    4    1    4   13    6    0    7    4  968]]

打印权重

weights1 = session.run(weights_conv1)
plot_conv_weights(weights=weights1)

重新初始化

再次重新初始化所有的变量

session.run(tf.global_variables_initializer())    # 重新初始化

print_test_accuracy()     # 准确率降回随机
测试集准确率: 19.4% (1937 / 10000)

可见,输出降回了随机的情况。

weights1 = session.run(weights_conv1)   
plot_conv_weights(weights=weights1)     # 权重也与上面的权重不同

权重也与训练好的模型大不相同。

恢复路径下的变量

现在需要从变量所保存的路径下恢复所有的变量。

saver.restore(sess=session, save_path=save_path)   # 现在从保存的目录中重新载入所有的变量

再次计算准确率

# 再次打印测试准确率
print_test_accuracy(show_example_errors=True, show_confusion_matrix=True)
测试集准确率: 98.6% (9864 / 10000)
Example errors:
Confusion Matrix:
[[ 974    0    1    0    0    1    2    0    1    1]
 [   0 1126    3    0    0    0    2    1    3    0]
 [   0    0 1022    1    1    0    0    3    4    1]
 [   0    0    1 1002    0    3    0    1    2    1]
 [   0    0    1    0  980    0    1    0    0    0]
 [   2    0    0    4    0  882    2    0    0    2]
 [   3    2    0    0    3    6  942    0    2    0]
 [   1    3    7    3    0    0    0 1005    1    8]
 [   4    0    6    6    1    3    1    2  946    5]
 [   2    3    2    3    7    4    0    2    1  985]]

可以发现,测试集的准确率达到了 98.6%,与前面训练好的模型相差无几。

weights1 = session.run(weights_conv1)
plot_conv_weights(weights=weights1)

这一次权重和训练过后的就基本相同了,存在轻微不同的原因是原始的模型多训练了 1000 次。

在重新载入变量后,我们还可以继续优化这些变量。

关闭 session

session.close()