diff --git a/README.md b/README.md index c50321f..80ee2ff 100644 --- a/README.md +++ b/README.md @@ -191,11 +191,11 @@ Define some hyper-parameters: ```python task = 'CartPole-v0' lr, epoch, batch_size = 1e-3, 10, 64 -train_num, test_num = 8, 100 +train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, collect_per_step = 1000, 8 +step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! ``` @@ -232,8 +232,8 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( - policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, + policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect, + test_num, batch_size, update_per_step=1 / step_per_collect, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index b3e1263..26ee8d2 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -284,7 +284,7 @@ policy.process_fn The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. -Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: +Take 2-step return DQN as an example. The 2-step return DQN computes each transition's return as: .. math:: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 361f79f..40e4a39 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -35,10 +35,10 @@ If you want to use the original ``gym.Env``: Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`) :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) -Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``. +Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. For the demonstration, here we use the second code-block. @@ -87,7 +87,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. 2. Output: some ``logits``, the next hidden state ``state``.
The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. @@ -113,7 +113,7 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True) + train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) @@ -125,8 +125,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, - max_epoch=10, step_per_epoch=1000, collect_per_step=10, - episode_per_test=100, batch_size=64, + max_epoch=10, step_per_epoch=10000, step_per_collect=10, + update_per_step=0.1, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, @@ -136,8 +136,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of step for updating policy network in one epoch; -* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; +* ``step_per_epoch``: The number of environment steps (a.k.a. transitions) collected per epoch; +* ``step_per_collect``: The number of transitions the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". @@ -205,7 +205,7 @@ Train a Policy with Customized Codes Tianshou supports user-defined training code.
Here is the code snippet: :: - # pre-collect at least 5000 frames with random action before training + # pre-collect at least 5000 transitions with random action before training train_collector.collect(n_step=5000, random=True) policy.set_eps(0.1) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index c656c1e..3d7f281 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -200,8 +200,9 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -293,7 +294,7 @@ With the above preparation, we are close to the first learned agent. The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward: {result["rews"][:, args.agent_id - 1].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) @@ -355,10 +356,10 @@ With the above preparation, we are close to the first learned agent. The followi # start training, this may require about three minutes result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) agent = policy.policies[args.agent_id - 1] # let's watch the match! 
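Taken together, the README, ``dqn.rst``, and ``tictactoe.rst`` hunks above describe one renamed off-policy interface: ``step_per_epoch`` now counts environment transitions per epoch, ``step_per_collect`` replaces ``collect_per_step``, and ``update_per_step`` sets how many gradient steps follow each collect. A minimal CartPole sketch of that call pattern, pieced together from the tutorial snippets above, looks roughly as follows; the ``DQNPolicy`` construction, the state/action-shape lookup, and the hidden layer sizes do not appear in these hunks and are assumptions borrowed from the surrounding tutorial rather than part of the patch:

```python
import gym
import torch
import tianshou as ts
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils.net.common import Net

# Hyper-parameters as in the README hunk above.
task = 'CartPole-v0'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn')

env = gym.make(task)
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

# Network, optimizer, and policy as in the DQN tutorial (assumed, not part of this patch).
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)

# One buffer slice per training environment, mirroring VectorReplayBuffer(20000, 10) above.
train_collector = ts.data.Collector(
    policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num),
    exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

# Renamed trainer arguments: step_per_collect transitions are gathered before each
# round of network updates.
result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect,
    test_num, batch_size, update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    writer=writer)
print(f'Finished training! Use {result["duration"]}')
```

Setting ``update_per_step=1 / step_per_collect`` keeps one network update per collect, which matches the tutorial's "collect 10 transitions and do one policy network update" description.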
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e2b8f07..72b8d13 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) @@ -140,7 +140,7 @@ def test_discrete_bcq(args=get_args()): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, log_interval=args.log_interval, ) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 42cffa9..d0a7ab8 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -30,8 +30,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -141,9 +142,10 @@ def test_c51(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 559b087..077d248 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -151,9 +152,10 @@ def test_dqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, 
args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index ed35638..e2eed3c 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -139,9 +140,10 @@ def test_qrdqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 0b81cec..5b760c1 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -4,7 +4,6 @@ import pprint import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter - from tianshou.policy import A2CPolicy from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net @@ -24,7 +23,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -91,8 +90,8 @@ def test_a2c(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
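A pattern worth noting in the example scripts this patch touches: every off-policy script pairs ``step_per_collect`` with an ``update_per_step`` whose product is one (10 with 0.1 in the Atari scripts above, and later 16 with 0.0625, 100 with 0.01, 5 with 0.2, 4 with 0.25, 8 with 0.125). Assuming the off-policy trainer performs roughly ``round(update_per_step * collected_transitions)`` gradient steps after each collect, which is the semantics implied by the README's ``update_per_step=1 / step_per_collect``, each of these defaults reproduces the old one-update-per-collect behaviour. A small self-contained sketch of that bookkeeping:

```python
def updates_per_collect(step_per_collect: int, update_per_step: float) -> int:
    # Assumed accounting: the off-policy trainer runs about
    # round(update_per_step * collected_transitions) gradient steps per collect.
    return round(update_per_step * step_per_collect)


# Pairings taken from the argparse defaults in this patch; each yields one
# network update per collect.
pairs = [(10, 0.1), (16, 0.0625), (100, 0.01), (5, 0.2), (4, 0.25), (8, 0.125)]
for step_per_collect, update_per_step in pairs:
    assert updates_per_collect(step_per_collect, update_per_step) == 1
```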
diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 8ed04c2..8a2a684 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -95,8 +95,8 @@ def test_ppo(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 444c357..69d0bfb 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -25,8 +25,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=100) + parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, @@ -103,8 +104,8 @@ def test_dqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 0bf802d..a903008 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--auto-alpha', type=int, default=1) parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ 
-143,9 +144,9 @@ def test_sac_bipedal(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, test_in_train=False, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index de88aa3..3d5d033 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=80000) + parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -99,10 +100,9 @@ def test_dqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, + test_fn=test_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 14e5095..333dab4 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--auto_alpha', type=int, default=1) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=12000) + parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -112,8 +113,10 @@ def test_sac(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) + assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ba9cd79..22dda7d 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ 
-28,8 +28,8 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) - parser.add_argument('--update-per-step', type=int, default=1) parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--batch-size', type=int, default=256) @@ -139,10 +140,9 @@ def test_sac(args=get_args()): train_collector.collect(n_step=args.pre_collect_step, random=True) result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, args.update_per_step, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - log_interval=args.log_interval) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, log_interval=args.log_interval) pprint.pprint(result) watch() diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index b9a6e01..a83abc3 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -87,7 +87,7 @@ def test_ddpg(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 2f83702..1bda7aa 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -96,7 +96,7 @@ def test_td3(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git
a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 6f34ce0..d64155d 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=200) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -97,7 +97,7 @@ def test_sac(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, log_interval=args.log_interval) assert stop_fn(result['best_reward']) diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 76271f4..f23e1b0 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -104,7 +104,7 @@ def test_td3(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 68d3fc4..afe88e1 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-epoch', type=int, default=9600) + parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -102,8 +103,9 @@ def test_ddpg(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ 
== '__main__': pprint.pprint(result) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 45a59f4..4f8ede1 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -23,8 +23,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=150000) + parser.add_argument('--episode-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, @@ -121,8 +121,8 @@ def test_ppo(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 6e075bb..0a96dbf 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -26,8 +26,10 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--il-step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -110,8 +112,9 @@ def test_sac_with_il(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -142,7 +145,7 @@ def test_sac_with_il(args=get_args()): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch // 5, args.collect_per_step, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index c90a92a..bbc32d9 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--noise-clip', type=float, default=0.5) 
parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=20000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -115,8 +116,9 @@ def test_td3(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ea7e6b6..1032b31 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -23,8 +23,11 @@ def get_args(): parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--il-step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -96,8 +99,8 @@ def test_a2c_with_il(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -121,13 +124,12 @@ def test_a2c_with_il(args=get_args()): il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( il_policy, - DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 684ce96..fe0573e 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -28,8 +28,9 @@ def 
get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -112,7 +113,7 @@ def test_c51(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index df02684..e9104c1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -25,9 +25,10 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -114,9 +115,9 @@ def test_dqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 420f8e6..dc2e06c 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=10) @@ -92,9 +93,10 @@ def test_drqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, 
args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, + train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 3dd2b8d..09c4c52 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128]) @@ -91,7 +91,7 @@ def test_discrete_bcq(args=get_args()): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index a413111..784ae70 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -21,8 +21,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -82,8 +82,8 @@ def test_pg(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 9862ea7..35634c6 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -22,8 +22,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--collect-per-step', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--episode-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) 
parser.add_argument('--hidden-sizes', type=int, @@ -108,8 +108,8 @@ def test_ppo(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 006dd82..e5ce61b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -110,9 +111,10 @@ def test_qrdqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 16ab54c..ebcb751 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.05) parser.add_argument('--auto_alpha', type=int, default=0) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -108,9 +109,9 @@ def test_discrete_sac(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 5a81387..01ea98a 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -17,8 +17,8 @@ def get_args(): parser.add_argument('--seed', type=int, default=1626) 
parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=5) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -78,8 +78,8 @@ def test_psrl(args=get_args()): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, 1, - args.test_num, 0, stop_fn=stop_fn, writer=writer, + args.step_per_epoch, 1, args.test_num, 0, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer, test_in_train=False) if __name__ == '__main__': diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index b081a50..edf066e 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -28,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -162,10 +163,10 @@ def train_agent( # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) return result, policy.policies[args.agent_id - 1] @@ -183,4 +184,4 @@ def watch( collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b24e8c6..477a253 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -605,7 +605,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage class VectorReplayBuffer(ReplayBufferManager): """VectorReplayBuffer contains n ReplayBuffer with the same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transitions from different environments yet keeping the order of time. :param int total_size: the total size of VectorReplayBuffer.
@@ -631,7 +631,7 @@ class VectorReplayBuffer(ReplayBufferManager): class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transitions from different environments yet keeping the order of time. :param int total_size: the total size of PrioritizedVectorReplayBuffer. diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 74ee72d..bb3239e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -198,7 +198,7 @@ class Collector(object): if not n_step % self.env_num == 0: warnings.warn( f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra frame collected into the buffer." + "which may cause extra transitions to be collected into the buffer." ) ready_env_ids = np.arange(self.env_num) elif n_episode is not None: @@ -357,9 +357,9 @@ class AsyncCollector(Collector): ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. - This function doesn't collect exactly n_step or n_episode number of frames. - Instead, in order to support async setting, it may collect more than given - n_step or n_episode frames and save into buffer. + This function doesn't collect exactly n_step or n_episode number of + transitions. Instead, in order to support async setting, it may collect more + than given n_step or n_episode transitions and save into buffer. :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. @@ -395,7 +395,7 @@ class AsyncCollector(Collector): else: raise TypeError("Please specify at least one (either n_step or n_episode) " "in AsyncCollector.collect().") - warnings.warn("Using async setting may collect extra frames into buffer.") + warnings.warn("Using async setting may collect extra transitions into buffer.") ready_env_ids = self._ready_env_ids diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index be6d821..9023bf4 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,21 +220,17 @@ class BasePolicy(ABC, nn.Module): Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param batch: a data batch which contains several episodes of data + :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). - :type batch: :class:`~tianshou.data.Batch` - :param numpy.ndarray indice: tell batch's location in buffer, batch is + :param np.ndarray indice: tell batch's location in buffer, batch is equal to buffer[indice]. - :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param float gae_lambda: the parameter for Generalized Advantage - Estimation, should be in [0, 1], defaults to 0.95. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+ :param float gae_lambda: the parameter for Generalized Advantage Estimation, + should be in [0, 1]. Default to 0.95. + :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). @@ -273,18 +269,14 @@ class BasePolicy(ABC, nn.Module): where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. - :param batch: a data batch, which is equal to buffer[indice]. - :type batch: :class:`~tianshou.data.Batch` - :param buffer: the data buffer. - :type buffer: :class:`~tianshou.data.ReplayBuffer` + :param Batch batch: a data batch, which is equal to buffer[indice]. + :param ReplayBuffer buffer: the data buffer. :param function target_q_fn: a function which compute target Q value of "obs_next" given data buffer and wanted indices. - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param int n_step: the number of estimation step, should be an int - greater than 0, defaults to 1. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param int n_step: the number of estimation steps, should be an int greater + than 0. Default to 1. + :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor. diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 92b7f5d..82fb9f7 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -45,7 +45,7 @@ class PGPolicy(BasePolicy): def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - r"""Compute the discounted returns for each frame. + r"""Compute the discounted returns for each transition. .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 13f9159..9c7f132 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -38,5 +38,5 @@ class RandomPolicy(BasePolicy): return Batch(act=logits.argmax(axis=-1)) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - """Since a random agent learn nothing, it returns an empty dict.""" + """Since a random agent learns nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 7eb6ec5..61714f7 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -16,7 +16,7 @@ def offline_trainer( buffer: ReplayBuffer, test_collector: Collector, max_epoch: int, - step_per_epoch: int, + update_per_epoch: int, episode_per_test: int, batch_size: int, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, @@ -29,50 +29,52 @@ ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. - The "step" in trainer means a policy network update. + The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``.
-    :param int step_per_epoch: the number of policy network updates, so-called
+    :param Collector test_collector: the collector used for testing.
+    :param int max_epoch: the maximum number of epochs for training. The training
+        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
+    :param int update_per_epoch: the number of policy network updates, so-called
         gradient steps, per epoch.
     :param episode_per_test: the number of episodes for one policy evaluation.
-    :param int batch_size: the batch size of sample data, which is going to
-        feed in the policy network.
-    :param function test_fn: a hook called at the beginning of testing in each
-        epoch. It can be used to perform custom additional operations, with the
-        signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
-        reward in evaluation phase gets better, with the signature ``f(policy:
-        BasePolicy) -> None``.
-    :param function stop_fn: a function with signature ``f(mean_rewards: float)
-        -> bool``, receives the average undiscounted returns of the testing
-        result, returns a boolean which indicates whether reaching the goal.
+    :param int batch_size: the batch size of sample data, which is going to be fed
+        into the policy network.
+    :param function test_fn: a hook called at the beginning of testing in each epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function save_fn: a hook called when the undiscounted average mean reward
+        in the evaluation phase gets better, with the signature
+        ``f(policy: BasePolicy) -> None``.
+    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+        bool``, which receives the average undiscounted returns of the testing result
+        and returns a boolean indicating whether the goal is reached.
     :param function reward_metric: a function with signature ``f(rewards: np.ndarray
         with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
         used in multi-agent RL. We need to return a single scalar for each episode's
         result to monitor training in the multi-agent RL setting. This function
         specifies what is the desired metric, e.g., the reward of agent 1 or the
         average reward over all agents.
-    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
-        SummaryWriter; if None is given, it will not write logs to TensorBoard.
-    :param int log_interval: the log interval of the writer.
-    :param bool verbose: whether to print the information.
+    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
+        if None is given, it will not write logs to TensorBoard. Default to None.
+    :param int log_interval: the log interval of the writer. Default to 1.
+    :param bool verbose: whether to print the information. Default to True.
 
     :return: See :func:`~tianshou.trainer.gather_info`.
""" gradient_step = 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, gradient_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( - step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: for i in t: gradient_step += 1 @@ -87,16 +89,18 @@ def offline_trainer( global_step=gradient_step) t.set_postfix(**data) # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result['rews'].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, gradient_step, + reward_metric) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index ba08c2e..54e7cb1 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -17,10 +17,10 @@ def offpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, + step_per_collect: int, episode_per_test: int, batch_size: int, - update_per_step: int = 1, + update_per_step: Union[int, float] = 1, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, @@ -33,60 +33,62 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment step (a.k.a. transition). :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of frames the collector would - collect before the network update. In other words, collect some frames - and do some policy network update. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. 
The training
+        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
+    :param int step_per_epoch: the number of transitions collected per epoch.
+    :param int step_per_collect: the number of transitions the collector would collect
+        before the network update, i.e., the trainer will collect "step_per_collect"
+        transitions and do some policy network updates repeatedly in each epoch.
     :param episode_per_test: the number of episodes for one policy evaluation.
-    :param int batch_size: the batch size of sample data, which is going to
-        feed in the policy network.
-    :param int update_per_step: the number of times the policy network would
-        be updated after frames are collected, for example, set it to 256 means
-        it updates policy 256 times once after ``collect_per_step`` frames are
-        collected.
-    :param function train_fn: a hook called at the beginning of training in
-        each epoch. It can be used to perform custom additional operations,
-        with the signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function test_fn: a hook called at the beginning of testing in each
-        epoch. It can be used to perform custom additional operations, with the
-        signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
-        reward in evaluation phase gets better, with the signature ``f(policy:
-        BasePolicy) -> None``.
-    :param function stop_fn: a function with signature ``f(mean_rewards: float)
-        -> bool``, receives the average undiscounted returns of the testing
-        result, returns a boolean which indicates whether reaching the goal.
+    :param int batch_size: the batch size of sample data, which is going to be fed
+        into the policy network.
+    :param int/float update_per_step: the number of times the policy network would be
+        updated per transition after ``step_per_collect`` transitions are collected,
+        e.g., if update_per_step is set to 0.3 and step_per_collect is 256, the policy
+        will be updated round(256 * 0.3) = 77 times after 256 transitions are
+        collected by the collector. Default to 1.
+    :param function train_fn: a hook called at the beginning of training in each epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function test_fn: a hook called at the beginning of testing in each epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function save_fn: a hook called when the undiscounted average mean reward
+        in the evaluation phase gets better, with the signature
+        ``f(policy: BasePolicy) -> None``.
+    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+        bool``, which receives the average undiscounted returns of the testing result
+        and returns a boolean indicating whether the goal is reached.
     :param function reward_metric: a function with signature ``f(rewards: np.ndarray
         with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
         used in multi-agent RL. We need to return a single scalar for each episode's
         result to monitor training in the multi-agent RL setting. This function
         specifies what is the desired metric, e.g., the reward of agent 1 or the
         average reward over all agents.
-    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
-        SummaryWriter; if None is given, it will not write logs to TensorBoard.
-    :param int log_interval: the log interval of the writer.
- :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. """ env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +98,11 @@ def offpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=collect_per_step) + result = train_collector.collect(n_step=step_per_collect) if len(result["rews"]) > 0 and reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -126,8 +129,7 @@ def offpolicy_trainer( test_result["rews"].mean(), test_result["rews"].std()) else: policy.train() - for i in range(update_per_step * min( - result["n/st"] // collect_per_step, t.total - t.n)): + for i in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): @@ -136,21 +138,21 @@ def offpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(1) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, env_step, reward_metric) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index b951f9a..43fcc87 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -17,10 +17,11 @@ def onpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, repeat_per_collect: int, episode_per_test: int, batch_size: int, + step_per_collect: 
Optional[int] = None,
+    episode_per_collect: Optional[int] = None,
     train_fn: Optional[Callable[[int, int], None]] = None,
     test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
     stop_fn: Optional[Callable[[float], bool]] = None,
@@ -33,60 +34,67 @@ ) -> Dict[str, Union[float, str]]:
     """A wrapper for on-policy trainer procedure.
 
-    The "step" in trainer means a policy network update.
+    The "step" in trainer means an environment step (a.k.a. transition).
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
-    :param train_collector: the collector used for training.
-    :type train_collector: :class:`~tianshou.data.Collector`
-    :param test_collector: the collector used for testing.
-    :type test_collector: :class:`~tianshou.data.Collector`
-    :param int max_epoch: the maximum number of epochs for training. The
-        training process might be finished before reaching the ``max_epoch``.
-    :param int step_per_epoch: the number of policy network updates, so-called
-        gradient steps, per epoch.
-    :param int collect_per_step: the number of episodes the collector would
-        collect before the network update. In other words, collect some
-        episodes and do one policy network update.
-    :param int repeat_per_collect: the number of repeat time for policy
-        learning, for example, set it to 2 means the policy needs to learn each
-        given batch data twice.
-    :param episode_per_test: the number of episodes for one policy evaluation.
-    :type episode_per_test: int or list of ints
-    :param int batch_size: the batch size of sample data, which is going to
-        feed in the policy network.
-    :param function train_fn: a hook called at the beginning of training in
-        each epoch. It can be used to perform custom additional operations,
-        with the signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function test_fn: a hook called at the beginning of testing in each
-        epoch. It can be used to perform custom additional operations, with the
-        signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
-        reward in evaluation phase gets better, with the signature ``f(policy:
-        BasePolicy) -> None``.
-    :param function stop_fn: a function with signature ``f(mean_rewards: float)
-        -> bool``, receives the average undiscounted returns of the testing
-        result, returns a boolean which indicates whether reaching the goal.
+    :param Collector train_collector: the collector used for training.
+    :param Collector test_collector: the collector used for testing.
+    :param int max_epoch: the maximum number of epochs for training. The training
+        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
+    :param int step_per_epoch: the number of transitions collected per epoch.
+    :param int repeat_per_collect: the number of repeat times for policy learning;
+        for example, setting it to 2 means the policy needs to learn each given batch
+        of data twice.
+    :param int episode_per_test: the number of episodes for one policy evaluation.
+    :param int batch_size: the batch size of sample data, which is going to be fed
+        into the policy network.
+    :param int step_per_collect: the number of transitions the collector would collect
+        before the network update, i.e., the trainer will collect "step_per_collect"
+        transitions and do some policy network updates repeatedly in each epoch.
+    :param int episode_per_collect: the number of episodes the collector would collect
+        before the network update, i.e., the trainer will collect "episode_per_collect"
+        episodes and do some policy network updates repeatedly in each epoch.
+    :param function train_fn: a hook called at the beginning of training in each epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function test_fn: a hook called at the beginning of testing in each epoch.
+        It can be used to perform custom additional operations, with the signature
+        ``f(num_epoch: int, step_idx: int) -> None``.
+    :param function save_fn: a hook called when the undiscounted average mean reward
+        in the evaluation phase gets better, with the signature
+        ``f(policy: BasePolicy) -> None``.
+    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
+        bool``, which receives the average undiscounted returns of the testing result
+        and returns a boolean indicating whether the goal is reached.
     :param function reward_metric: a function with signature ``f(rewards: np.ndarray
         with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
         used in multi-agent RL. We need to return a single scalar for each episode's
         result to monitor training in the multi-agent RL setting. This function
         specifies what is the desired metric, e.g., the reward of agent 1 or the
         average reward over all agents.
-    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
-        SummaryWriter; if None is given, it will not write logs to TensorBoard.
-    :param int log_interval: the log interval of the writer.
-    :param bool verbose: whether to print the information.
-    :param bool test_in_train: whether to test in the training phase.
+    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
+        if None is given, it will not write logs to TensorBoard. Default to None.
+    :param int log_interval: the log interval of the writer. Default to 1.
+    :param bool verbose: whether to print the information. Default to True.
+    :param bool test_in_train: whether to test in the training phase. Default to True.
 
     :return: See :func:`~tianshou.trainer.gather_info`.
+
+    .. note::
+
+        Only one of step_per_collect and episode_per_collect can be specified.
""" env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +104,12 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_episode=collect_per_step) + result = train_collector.collect(n_step=step_per_collect, + n_episode=episode_per_collect) if reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -138,21 +148,21 @@ def onpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(step) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, env_step) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2cdeb15..72803be 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -48,14 +48,14 @@ def gather_info( * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting frames in the \ + * ``train_time/collector`` the time for collecting transitions in the \ training collector; * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (frames per second); + * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (frames per second); + * ``test_speed`` the speed of testing (env_step per second); * ``best_reward`` the best reward over the test results; * ``duration`` the total elapsed time. """