这将在0到9的每批数据上训练一个新的VAE。模型权重将被保存到./rnn/weights.h5
中。new_model标志提示脚本从头开始训练模型。
和VAE一样,如果文件夹中存在weights.h5并且没有指定--new_model标记,那么脚本将从该文件加载权重,并继续训练现有模型。通过这种方式,您可以迭代地批量训练RNN。
找不到RNN架构说明的可以去翻翻./rnn/arch.py文件,可能会让你小小开心一下。
第八步:训练控制器这是一个愉快的章节。
上面,我们已经用深度学习搭建了一个VAE,它可以把高维图像压缩成一个低维潜在空间;还搭好了一个RNN,可以预测潜在空间随着时间推移会发生怎样的变化。能走到这一步,是因为我们给VAE和RNN各自装备了一个由随机rollout data组成的训练数据集。
现在,我们要使用一种强化学习方法,依靠名为CMA-ES的进化算法来训练控制器。
输入向量有288维,输出向量是3维。于是,我们一共有288 x 3 1 (bias) = 867个参数需要训练。
首先,CMA-ES要为这867个参数,创建多个随机初始化副本,形成种群 (population) 。而后,这个算法会在环境中,测试种群中的每一个成员,记录平均分。像达尔文的自然选择一样,分数比较高的那些权重就会获得“繁衍”后代的资格。
敲下这个代码,给每个参数选择一个合适的值,就可以开始训练了:
如果没有显示器的话,就用这个代码: