0%

Keras 自定义 generator

问题引入

在做kaggle上的Facial Keypoints Detection时候,它输入是一些人脸的图片, 标签是类似眼角、鼻尖、嘴角等在图片上的坐标。

整个数据集的图片数量并不多,7000+,而且中间数据集其实可以分为两个部分, 一部分是质量较好的数据集,图片数2000+,剩下的就是另外一部分质量较差的。 实际上在训练的时候,只用了质量较好的这一部分,那么也就2000+的图片, 对于图片类的任务来说,样本太少了。

那么这里显然使用Data Augmentation是比较明智的选择,也就是将图片进行一下平移、镜像、对比度变化、旋转等操作,反正这些操作并不会影响到图片本身的性质,但是在神经网络看来又不会是重复样本。

所以,这里就需要使用generator,也就是流式的数据生成,因为不可能先把图片进行变换,保存下来再进行训练,因为一张图片可以变换成为无数张图片,所以需要一边使用原始图片来生成新图片,一边进行训练。

在Keras中,model.fit_generator()就是用来进行这种类型的训练的,它需要传入一个生成器,也就是python中的生成器。

注意到,在Keras中,提供了ImageDataGenerator这么一个类可以来进行图片的变换,其中有很多的功能。 但是,但是,但是,这里的任务不是分类,是回归,其中标签是眼角、鼻尖、嘴角等在图片上的坐标, 那么在图片进行变换的时候,当然这里坐标也需要跟着变,那么这里就得自己来写generator


generator是什么?

我现在只有简单的理解,它就是一个python的生成器,每一次返回一个batch的样本以及标签。

1
2
3
4
5
6
7
8
9
10
11
def my_generator(X, Y, batch_size=32, gray_change_range=30, cut_out_size=88):
indexs = list(range(X.shape[0]))
while True:
np.random.shuffle(indexs)
for i in range(0, len(indexs), batch_size):
ge_batch_x = np.empty((batch_size, cut_out_size, cut_out_size))
ge_batch_y = np.empty((batch_size, Y.shape[1]))

...

yield ge_batch_x, ge_batch_y

生成器方法就类似于上面的代码,其中yield是关键的地方,程序每次运行到它的时候, 就会将它后面的数据返回,下一次调用又接着向后运行。

另外,这里写成了一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。


Keras使用generator进行训练

这里使用model.fit_generator()来进行训练即可:

1
2
3
fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10,
workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

函数声明中有很多可以选择的参数,具体的说明可以去查看官方的文档,里面介绍了每一个参数的说明, 通常我们只需要定义几个参数就够了。

1
model.fit_generator(my_generator(train_x, train_y), steps_per_epoch=train_x.shape[0] / 32, epochs=10)

上面的调用,参数中传入了生成器,多少个batch为一个epoch,训练多少个epoch。


关于参数设置多个workers的问题

关于workers的文档说明:

workers: Integer. Maximum number of processes to spin up when using process based threading. If unspecified, workers will default to 1. If 0, will execute the generator on the main thread.

我的理解,它是与数据生成有关的,应该是使用多个线程来生成数据。

由于有时候生成数据的过程可能较慢,它可能会拖慢整个训练的速度,所以需要使用多个线程来同时生成, 使得生成数据不会是训练时间的瓶颈。

但是,当我设置workers=2时,程序报错,提示generator already executing, google之后,大概确定了这个是没有线程安全而导致的问题。

我没有实际去解决,因为其实一个线程已经够用了。如果下次需要开启多个线程,解决方法可以参考,Proper way of making a data generator which can handle multiple workers #1638