diff --git a/.gitignore b/.gitignore
index f703e0e..4215253 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,3 +141,5 @@ flake8.sh
 log/
 MUJOCO_LOG.TXT
 *.pth
+.vscode/
+.DS_Store
diff --git a/docs/_static/css/style.css b/docs/_static/css/style.css
index 3e31531..6da8faf 100644
--- a/docs/_static/css/style.css
+++ b/docs/_static/css/style.css
@@ -112,6 +112,10 @@ footer p {
     font-size: 100%;
 }
 
+.ethical-rtd {
+    display: none;
+}
+
 /* For hidden headers that appear in TOC tree */
 /* see http://stackoverflow.com/a/32363545/3343043 */
 .rst-content .hidden-section {
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 1972c75..7ed8843 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
@@ -84,9 +83,12 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -94,7 +96,7 @@ def test_ddpg(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 6bbbdc5..e1877fb 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
@@ -92,9 +91,12 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -102,7 +104,8 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/continuous/test_sac.py b/test/continuous/test_sac.py
index 2d5b3df..16844d7 100644
--- a/test/continuous/test_sac.py
+++ b/test/continuous/test_sac.py
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
@@ -89,9 +88,12 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -99,7 +101,7 @@ def test_sac(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index e7e82c9..3cda0bd 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -20,7 +20,6 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
@@ -93,9 +92,12 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)
 
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
     def stop_fn(x):
         return x >= env.spec.reward_threshold
 
@@ -103,7 +105,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/discrete/test_a2c.py b/test/discrete/test_a2c.py
index f416c90..93cfe11 100644
--- a/test/discrete/test_a2c.py
+++ b/test/discrete/test_a2c.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -76,7 +77,11 @@ def test_a2c(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'a2c')
+    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -85,7 +90,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index bb712a4..2816613 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -74,7 +75,11 @@ def test_dqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -90,7 +95,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
 
     train_collector.close()
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 6bbc04b..2be8a05 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -78,7 +79,11 @@ def test_drqn(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+    log_path = os.path.join(args.logdir, args.task, 'drqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -94,7 +99,7 @@ def test_drqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
     assert stop_fn(result['best_reward'])
 
     train_collector.close()
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 081f48a..0293b07 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import time
 import torch
@@ -125,7 +126,11 @@ def test_pg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'pg')
+    log_path = os.path.join(args.logdir, args.task, 'pg')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -134,7 +139,8 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 0ba1f63..28ca0e9 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -81,7 +82,11 @@ def test_ppo(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -90,7 +95,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
+        writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index f0e5d9b..596e45a 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -49,7 +49,7 @@ class ReplayBuffer(object):
     >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
     >>> for i in range(16):
     ...     done = i % 5 == 0
-    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i, info={})
+    ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1)
     >>> print(buf)
     ReplayBuffer(
         obs: [ 9. 10. 11. 12. 13. 14. 15.  7.  8.],
@@ -74,6 +74,17 @@ class ReplayBuffer(object):
     >>> # (stack only for obs and obs_next)
     >>> abs(buf.get(index, 'obs') - buf[index].obs).sum().sum()
     0.0
+    >>> # even if obs_next is not stored, we can still get it via __getitem__
+    >>> print(buf[:].obs_next)
+    [[ 7.  8.  9. 10.]
+     [ 7.  8.  9. 10.]
+     [11. 11. 11. 12.]
+     [11. 11. 12. 13.]
+     [11. 12. 13. 14.]
+     [12. 13. 14. 15.]
+     [12. 13. 14. 15.]
+     [ 7.  7.  7.  8.]
+     [ 7.  7.  8.  9.]]
     """
 
     def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 3db6269..0d57a37 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size,
-                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                      writer=None, log_interval=1, verbose=True, task='',
-                      **kwargs):
+                      train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                      log_fn=None, writer=None, log_interval=1, verbose=True,
+                      task='', **kwargs):
     """A wrapper for off-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -35,6 +35,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in
         this epoch.
+    :param function save_fn: a function for saving the policy when the
+        undiscounted average reward in the evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates
        whether reaching the goal.
@@ -66,6 +68,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                         policy, test_collector, test_fn,
                         epoch, episode_per_test)
                     if stop_fn and stop_fn(test_result['rew']):
+                        if save_fn:
+                            save_fn(policy)
                         for k in result.keys():
                             data[k] = f'{result[k]:.2f}'
                         t.set_postfix(**data)
@@ -105,6 +109,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index a973089..1ae846b 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -8,9 +8,9 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
-                     writer=None, log_interval=1, verbose=True, task='',
-                     **kwargs):
+                     train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
+                     log_fn=None, writer=None, log_interval=1, verbose=True,
+                     task='', **kwargs):
     """A wrapper for on-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -39,6 +39,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function test_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of testing in
         this epoch.
+    :param function save_fn: a function for saving the policy when the
+        undiscounted average reward in the evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates
         whether reaching the goal.
@@ -70,6 +72,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                         policy, test_collector, test_fn,
                         epoch, episode_per_test)
                     if stop_fn and stop_fn(test_result['rew']):
+                        if save_fn:
+                            save_fn(policy)
                         for k in result.keys():
                             data[k] = f'{result[k]:.2f}'
                         t.set_postfix(**data)
@@ -113,6 +117,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         if best_epoch == -1 or best_reward < result['rew']:
             best_reward = result['rew']
             best_epoch = epoch
+            if save_fn:
+                save_fn(policy)
         if verbose:
             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                   f'best_reward: {best_reward:.6f} in #{best_epoch}')
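
Usage note (not part of the patch): a minimal sketch of how the save_fn hook introduced above is wired into a training script, assuming args (an argparse namespace), policy, train_collector, test_collector and env are built exactly as in test/continuous/test_ddpg.py; only the trainer call and the two callbacks are shown.

    import os
    import torch
    from torch.utils.tensorboard import SummaryWriter
    from tianshou.trainer import offpolicy_trainer

    # assumed to exist, as set up in test/continuous/test_ddpg.py:
    #   args, policy, train_collector, test_collector, env
    log_path = os.path.join(args.logdir, args.task, 'ddpg')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        # called by the trainer when the best evaluation reward improves
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)

The hook is optional: both trainers call save_fn only when it is not None, once in the early-stop branch and once per epoch whenever best_reward improves, so scripts that omit it keep the previous behaviour.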