首页 > 编程知识 正文

OpenAI baseline GAIL代码讲解及其可视化,单片机led闪烁代码讲解

时间:2023-05-04 07:23:58 阅读:255484 作者:3688

最近在研究关于强化学习的部分工作,首先从OpenAI的Baseline中的小型GAIL算法出发。

首先参考了大神的文章从《西部世界》到GAIL(Generative Adversarial Imitation Learning)算法。

原文链接:https://blog.csdn.net/jinzhuojun/article/details/85220327#commentBox

对大神写的文章做一些补充和细节解释。

在baseline 的文件夹中运行即可以进行模型的训练

python3 -m baselines.gail.run_mujoco

在run_mujoco.py代码中写到

parser.add_argument('--task', type=str, choices=['train', 'evaluate', 'sample'], default='train')

 可以在命令行后面添加 --task 改变任务为train 和evaluate。evaluate后面要加上存储的模型的地址

# 假设训练模型放在/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/python3 -m baselines.gail.run_mujoco --task=evaluate --load_model_path=/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0

在baseline 中使用tensorflow方式存储模型:  在trpo_mpi.py  232行。

# Save model if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None: fname = os.path.join(ckpt_dir, task_name) #U.save_variables(fname) #print("the save path is ",fname) os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() saver.save(tf.get_default_session(), fname)

所以在checkpoint中存储了可以用tensorflow方式读取模型的三个文件,而在运行评估模型时读取模型的方式采用的是baseline 中common自己定义的    U.load_variables(load_model_path)来读取文件,读取文件的类型是上面由tensorflow生成的文件的集合体。

U.load_variables(load_model_path)

因此在存储模型的时候也应该采用common中的定义的save_variables来存储模型生成集成文件:

if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None: fname = os.path.join(ckpt_dir, task_name) U.save_variables(fname) print("the save path is ",fname) os.makedirs(os.path.dirname(fname), exist_ok=True) saver = tf.train.Saver() saver.save(tf.get_default_session(), fname)

 然后运行train的命令行,在训练100次迭代之后就可以在保存模型的文件夹中发现一个无.data/.index/.meta后缀的集成文件。

此时再运行evaluate命令行就可以出现对模型的评估返回数据

在run_mujoco.py中的traj_1_generator函数中的while函数中插入env.render()就可以渲染出模型可视化结果。

while True: ac, vpred = pi.act(stochastic, ob) obs.append(ob) news.append(new) acs.append(ac) ob, rew, new, _ = env.step(ac) rews.append(rew) env.render() cur_ep_ret += rew cur_ep_len += 1 if new or t >= horizon: break t += 1

感谢大佬的分享,同时在遇到困难的时候还是要敢于挑战权威呀。

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。