diff --git a/README.md b/README.md
index 3600434..21e8cdb 100644
--- a/README.md
+++ b/README.md
@@ -229,9 +229,11 @@ Let's train it:
 ```python
 result = ts.trainer.offpolicy_trainer(
     policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
-    test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
-    test_fn=lambda e: policy.set_eps(eps_test),
-    stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
+    test_num, batch_size,
+    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
+    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
+    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
+    writer=writer, task=task)
 print(f'Finished training! Use {result["duration"]}')
 ```
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index d923b56..49f6260 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -123,9 +123,9 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
         policy, train_collector, test_collector,
         max_epoch=10, step_per_epoch=1000, collect_per_step=10,
         episode_per_test=100, batch_size=64,
-        train_fn=lambda e: policy.set_eps(0.1),
-        test_fn=lambda e: policy.set_eps(0.05),
-        stop_fn=lambda x: x >= env.spec.reward_threshold,
+        train_fn=lambda epoch, env_step: policy.set_eps(0.1),
+        test_fn=lambda epoch, env_step: policy.set_eps(0.05),
+        stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
         writer=None)
     print(f'Finished training! Use {result["duration"]}')
@@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m
 * ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
 * ``episode_per_test``: The number of episodes for one policy evaluation.
 * ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
-* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
-* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
+* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
+* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
 * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
 * ``writer``: See below.
diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index cc4116d..a511af2 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -334,15 +334,15 @@ With the above preparation, we are close to the first learned agent. The followi
             policy.policies[args.agent_id - 1].state_dict(), model_save_path)

-    def stop_fn(x):
-        return x >= args.win_rate  # 95% winning rate by default
+    def stop_fn(mean_rewards):
+        return mean_rewards >= args.win_rate  # 95% winning rate by default
         # the default args.win_rate is 0.9, but the reward is [-1, 1]
         # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate.

-    def train_fn(x):
+    def train_fn(epoch, env_step):
         policy.policies[args.agent_id - 1].set_eps(args.eps_train)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.policies[args.agent_id - 1].set_eps(args.eps_test)

     # start training, this may require about three minutes
diff --git a/examples/atari/README.md b/examples/atari/README.md
index 40c025c..474f74c 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -12,14 +12,14 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 | task | best reward | reward curve | parameters | time cost |
 | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
-| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
-| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
-| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
-| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
-| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
-| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
-| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
+| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) |
+| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) |
+| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |

-Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.
+Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.

 We haven't tuned this result to the best, so have fun with playing these hyperparameters!
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 92bc045..e448903 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -18,20 +18,20 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--eps_test', type=float, default=0.005)
-    parser.add_argument('--eps_train', type=float, default=1.)
-    parser.add_argument('--eps_train_final', type=float, default=0.05)
+    parser.add_argument('--eps-test', type=float, default=0.005)
+    parser.add_argument('--eps-train', type=float, default=1.)
+    parser.add_argument('--eps-train-final', type=float, default=0.05)
     parser.add_argument('--buffer-size', type=int, default=100000)
     parser.add_argument('--lr', type=float, default=0.0001)
     parser.add_argument('--gamma', type=float, default=0.99)
-    parser.add_argument('--n_step', type=int, default=3)
-    parser.add_argument('--target_update_freq', type=int, default=500)
+    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--target-update-freq', type=int, default=500)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step_per_epoch', type=int, default=10000)
-    parser.add_argument('--collect_per_step', type=int, default=10)
-    parser.add_argument('--batch_size', type=int, default=32)
-    parser.add_argument('--training_num', type=int, default=16)
-    parser.add_argument('--test_num', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=10000)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=32)
+    parser.add_argument('--training-num', type=int, default=16)
+    parser.add_argument('--test-num', type=int, default=10)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument(
@@ -95,26 +95,25 @@ def test_dqn(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
+    def stop_fn(mean_rewards):
         if env.env.spec.reward_threshold:
-            return x >= env.spec.reward_threshold
+            return mean_rewards >= env.spec.reward_threshold
         elif 'Pong' in args.task:
-            return x >= 20
+            return mean_rewards >= 20
         else:
             return False

-    def train_fn(x):
+    def train_fn(epoch, env_step):
         # nature DQN setting, linear decay in the first 1M steps
-        now = x * args.collect_per_step * args.step_per_epoch
-        if now <= 1e6:
-            eps = args.eps_train - now / 1e6 * \
+        if env_step <= 1e6:
+            eps = args.eps_train - env_step / 1e6 * \
                 (args.eps_train - args.eps_train_final)
         else:
             eps = args.eps_train_final
         policy.set_eps(eps)
-        writer.add_scalar('train/eps', eps, global_step=now)
+        writer.add_scalar('train/eps', eps, global_step=env_step)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

     # watch agent's performance
diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py
index f4b0a30..290d11d 100644
--- a/examples/atari/runnable/pong_a2c.py
+++ b/examples/atari/runnable/pong_a2c.py
@@ -1,3 +1,4 @@
+import os
 import torch
 import pprint
 import argparse
@@ -76,11 +77,11 @@ def test_a2c(args=get_args()):
                           preprocess_fn=preprocess_fn)
     test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'a2c')
+    writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))

-    def stop_fn(x):
+    def stop_fn(mean_rewards):
         if env.env.spec.reward_threshold:
-            return x >= env.spec.reward_threshold
+            return mean_rewards >= env.spec.reward_threshold
         else:
             return False
diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py
index 9d5563f..4e898c8 100644
--- a/examples/atari/runnable/pong_ppo.py
+++ b/examples/atari/runnable/pong_ppo.py
@@ -1,3 +1,4 @@
+import os
 import torch
 import pprint
 import argparse
@@ -80,11 +81,11 @@ def test_ppo(args=get_args()):
                           preprocess_fn=preprocess_fn)
     test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))

-    def stop_fn(x):
+    def stop_fn(mean_rewards):
         if env.env.spec.reward_threshold:
-            return x >= env.spec.reward_threshold
+            return mean_rewards >= env.spec.reward_threshold
         else:
             return False
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index 6345d62..4408283 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -6,11 +6,11 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.env import DummyVectorEnv
 from tianshou.policy import DQNPolicy
+from tianshou.env import DummyVectorEnv
+from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.utils.net.common import Net


 def get_args():
@@ -75,20 +75,20 @@ def test_dqn(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

-    def train_fn(x):
-        if x <= int(0.1 * args.epoch):
+    def train_fn(epoch, env_step):
+        if env_step <= 100000:
             policy.set_eps(args.eps_train)
-        elif x <= int(0.5 * args.epoch):
-            eps = args.eps_train - (x - 0.1 * args.epoch) / \
-                (0.4 * args.epoch) * (0.5 * args.eps_train)
+        elif env_step <= 500000:
+            eps = args.eps_train - (env_step - 100000) / \
+                400000 * (0.5 * args.eps_train)
             policy.set_eps(eps)
         else:
             policy.set_eps(0.5 * args.eps_train)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

     # trainer
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index ffd3e8f..b4da185 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -74,9 +74,6 @@ class EnvWrapper(object):
 def test_sac_bipedal(args=get_args()):
     env = EnvWrapper(args.task)

-    def IsStop(reward):
-        return reward >= env.spec.reward_threshold
-
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     args.max_action = env.action_space.high[0]
@@ -141,11 +138,14 @@ def test_sac_bipedal(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
         test_in_train=False)

 if __name__ == '__main__':
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index aa0f588..0bdb228 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -18,7 +18,7 @@ def get_args():
     # the parameters are found by Optuna
     parser.add_argument('--task', type=str, default='LunarLander-v2')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-test', type=float, default=0.01)
     parser.add_argument('--eps-train', type=float, default=0.73)
     parser.add_argument('--buffer-size', type=int, default=100000)
     parser.add_argument('--lr', type=float, default=0.013)
@@ -77,14 +77,14 @@ def test_dqn(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

-    def train_fn(x):
-        args.eps_train = max(args.eps_train * 0.6, 0.01)
-        policy.set_eps(args.eps_train)
+    def train_fn(epoch, env_step):  # exp decay
+        eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test)
+        policy.set_eps(eps)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

     # trainer
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 9ca6845..b9481a5 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -98,8 +98,8 @@ def test_sac(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py
index 948ceee..dd4486d 100644
--- a/examples/mujoco/ant_v2_ddpg.py
+++ b/examples/mujoco/ant_v2_ddpg.py
@@ -6,11 +6,11 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import SubprocVectorEnv
-from tianshou.exploration import GaussianNoise
 from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.exploration import GaussianNoise
+from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import Actor, Critic
@@ -77,8 +77,8 @@ def test_ddpg(args=get_args()):
     # log
     writer = SummaryWriter(args.logdir + '/' + 'ddpg')

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py
index a86bcff..819c745 100644
--- a/examples/mujoco/ant_v2_sac.py
+++ b/examples/mujoco/ant_v2_sac.py
@@ -7,10 +7,10 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import SACPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -86,8 +86,8 @@ def test_sac(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py
index 7165315..9e43f10 100644
--- a/examples/mujoco/ant_v2_td3.py
+++ b/examples/mujoco/ant_v2_td3.py
@@ -6,11 +6,11 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.exploration import GaussianNoise
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-from tianshou.env import SubprocVectorEnv
-from tianshou.exploration import GaussianNoise
-from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
@@ -88,8 +88,8 @@ def test_td3(args=get_args()):
     # log
     writer = SummaryWriter(args.logdir + '/' + 'td3')

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py
index 97b3bc7..0559167 100644
--- a/examples/mujoco/halfcheetahBullet_v0_sac.py
+++ b/examples/mujoco/halfcheetahBullet_v0_sac.py
@@ -4,18 +4,14 @@ import torch
 import pprint
 import argparse
 import numpy as np
+import pybullet_envs
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.env import SubprocVectorEnv
 from tianshou.policy import SACPolicy
+from tianshou.utils.net.common import Net
+from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-
-try:
-    import pybullet_envs
-except ImportError:
-    pass
-from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -91,8 +87,8 @@ def test_sac(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
     writer = SummaryWriter(log_path)

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py
index 6de2b20..ff42716 100644
--- a/examples/mujoco/point_maze_td3.py
+++ b/examples/mujoco/point_maze_td3.py
@@ -6,12 +6,13 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
+from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import Actor, Critic
+
 from mujoco.register import reg
@@ -40,7 +41,6 @@ def get_args():
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
-    parser.add_argument('--max_episode_steps', type=int, default=2000)
     return parser.parse_args()
@@ -91,9 +91,9 @@ def test_td3(args=get_args()):
     # log
     writer = SummaryWriter(args.logdir + '/' + 'td3')

-    def stop_fn(x):
+    def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
-            return x >= env.spec.reward_threshold
+            return mean_rewards >= env.spec.reward_threshold
         else:
             return False
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 979444f..0991562 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -96,8 +96,8 @@ def test_ddpg(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index bee5af4..ef3692a 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -113,8 +113,8 @@ def test_ppo(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = onpolicy_trainer(
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 2067973..009218c 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -96,8 +96,8 @@ def test_sac_with_il(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index a6215e0..8479714 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -103,8 +103,8 @@ def test_td3(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index dda7704..b0c31e3 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -84,8 +84,8 @@ def test_a2c_with_il(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = onpolicy_trainer(
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 4d28d38..7564c08 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -85,21 +85,21 @@ def test_dqn(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

-    def train_fn(x):
+    def train_fn(epoch, env_step):
         # eps annnealing, just a demo
-        if x <= int(0.1 * args.epoch):
+        if env_step <= 10000:
             policy.set_eps(args.eps_train)
-        elif x <= int(0.5 * args.epoch):
-            eps = args.eps_train - (x - 0.1 * args.epoch) / \
-                (0.4 * args.epoch) * (0.9 * args.eps_train)
+        elif env_step <= 50000:
+            eps = args.eps_train - (env_step - 10000) / \
+                40000 * (0.9 * args.eps_train)
             policy.set_eps(eps)
         else:
             policy.set_eps(0.1 * args.eps_train)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

     # trainer
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 5ef6c16..f3f00e6 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -79,13 +79,13 @@ def test_drqn(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

-    def train_fn(x):
+    def train_fn(epoch, env_step):
         policy.set_eps(args.eps_train)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.set_eps(args.eps_test)

     # trainer
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 3604adb..d84130b 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -73,8 +73,8 @@ def test_pg(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = onpolicy_trainer(
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index c8d8494..f6d23fe 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -98,8 +98,8 @@ def test_ppo(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = onpolicy_trainer(
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index 865924a..4b86073 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -93,8 +93,8 @@ def test_discrete_sac(args=get_args()):
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

-    def stop_fn(x):
-        return x >= env.spec.reward_threshold
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold

     # trainer
     result = offpolicy_trainer(
diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py
index 6fb0e16..29dfb6b 100644
--- a/test/modelbase/test_psrl.py
+++ b/test/modelbase/test_psrl.py
@@ -66,9 +66,9 @@ def test_psrl(args=get_args()):
     # log
     writer = SummaryWriter(args.logdir + '/' + args.task)

-    def stop_fn(x):
+    def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
-            return x >= env.spec.reward_threshold
+            return mean_rewards >= env.spec.reward_threshold
         else:
             return False
diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py
index 9110b9d..1321833 100644
--- a/test/multiagent/tic_tac_toe.py
+++ b/test/multiagent/tic_tac_toe.py
@@ -64,11 +64,12 @@ def get_args() -> argparse.Namespace:
     return args


-def get_agents(args: argparse.Namespace = get_args(),
-               agent_learn: Optional[BasePolicy] = None,
-               agent_opponent: Optional[BasePolicy] = None,
-               optim: Optional[torch.optim.Optimizer] = None,
-               ) -> Tuple[BasePolicy, torch.optim.Optimizer]:
+def get_agents(
+    args: argparse.Namespace = get_args(),
+    agent_learn: Optional[BasePolicy] = None,
+    agent_opponent: Optional[BasePolicy] = None,
+    optim: Optional[torch.optim.Optimizer] = None,
+) -> Tuple[BasePolicy, torch.optim.Optimizer]:
     env = TicTacToeEnv(args.board_size, args.win_size)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
@@ -99,11 +100,12 @@ def get_agents(args: argparse.Namespace = get_args(),
     return policy, optim


-def train_agent(args: argparse.Namespace = get_args(),
-                agent_learn: Optional[BasePolicy] = None,
-                agent_opponent: Optional[BasePolicy] = None,
-                optim: Optional[torch.optim.Optimizer] = None,
-                ) -> Tuple[dict, BasePolicy]:
+def train_agent(
+    args: argparse.Namespace = get_args(),
+    agent_learn: Optional[BasePolicy] = None,
+    agent_opponent: Optional[BasePolicy] = None,
+    optim: Optional[torch.optim.Optimizer] = None,
+) -> Tuple[dict, BasePolicy]:
     def env_func():
         return TicTacToeEnv(args.board_size, args.win_size)
     train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
@@ -142,13 +144,13 @@ def train_agent(args: argparse.Namespace = get_args(),
             policy.policies[args.agent_id - 1].state_dict(), model_save_path)

-    def stop_fn(x):
-        return x >= args.win_rate
+    def stop_fn(mean_rewards):
+        return mean_rewards >= args.win_rate

-    def train_fn(x):
+    def train_fn(epoch, env_step):
         policy.policies[args.agent_id - 1].set_eps(args.eps_train)

-    def test_fn(x):
+    def test_fn(epoch, env_step):
         policy.policies[args.agent_id - 1].set_eps(args.eps_test)

     # trainer
@@ -162,10 +164,11 @@ def train_agent(args: argparse.Namespace = get_args(),
     return result, policy.policies[args.agent_id - 1]


-def watch(args: argparse.Namespace = get_args(),
-          agent_learn: Optional[BasePolicy] = None,
-          agent_opponent: Optional[BasePolicy] = None,
-          ) -> None:
+def watch(
+    args: argparse.Namespace = get_args(),
+    agent_learn: Optional[BasePolicy] = None,
+    agent_opponent: Optional[BasePolicy] = None,
+) -> None:
     env = TicTacToeEnv(args.board_size, args.win_size)
     policy, optim = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent)
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index 0b9c0e9..cc37f27 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, exploration

-__version__ = "0.3.0rc0"
+__version__ = "0.3.0"

 __all__ = [
     "env",
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 8cf4f6b..c8c572a 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -230,7 +230,8 @@ class ReplayBuffer:
             obs = obs[-1]
         self._add_to_buffer("obs", obs)
         self._add_to_buffer("act", act)
-        self._add_to_buffer("rew", rew)
+        # make sure the reward is a float instead of an int
+        self._add_to_buffer("rew", rew * 1.0)  # type: ignore
         self._add_to_buffer("done", done)
         if self._save_s_:
             if obs_next is None:
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 170fd68..75fd6cf 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -19,8 +19,8 @@ def offpolicy_trainer(
     episode_per_test: Union[int, List[int]],
     batch_size: int,
     update_per_step: int = 1,
-    train_fn: Optional[Callable[[int], None]] = None,
-    test_fn: Optional[Callable[[int], None]] = None,
+    train_fn: Optional[Callable[[int, int], None]] = None,
+    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
     stop_fn: Optional[Callable[[float], bool]] = None,
     save_fn: Optional[Callable[[BasePolicy], None]] = None,
     writer: Optional[SummaryWriter] = None,
@@ -53,11 +53,11 @@ def offpolicy_trainer(
         it updates policy 256 times once after ``collect_per_step`` frames are
         collected.
     :param function train_fn: a function receives the current number of epoch
-        index and performs some operations at the beginning of training in this
-        epoch.
+        and step index, and performs some operations at the beginning of
+        training in this epoch.
     :param function test_fn: a function receives the current number of epoch
-        index and performs some operations at the beginning of testing in this
-        epoch.
+        and step index, and performs some operations at the beginning of
+        testing in this epoch.
     :param function save_fn: a function for saving policy when the undiscounted
         average mean reward in evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
@@ -81,12 +81,12 @@ def offpolicy_trainer(
     for epoch in range(1, 1 + max_epoch):
         # train
         policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(
             total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
         ) as t:
             while t.n < t.total:
+                if train_fn:
+                    train_fn(epoch, global_step)
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result["rew"]):
@@ -104,8 +104,6 @@ def offpolicy_trainer(
                             test_result["rew"])
                     else:
                         policy.train()
-                        if train_fn:
-                            train_fn(epoch)
                 for i in range(update_per_step * min(
                         result["n/st"] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 877c634..023dd29 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -19,8 +19,8 @@ def onpolicy_trainer(
     repeat_per_collect: int,
     episode_per_test: Union[int, List[int]],
     batch_size: int,
-    train_fn: Optional[Callable[[int], None]] = None,
-    test_fn: Optional[Callable[[int], None]] = None,
+    train_fn: Optional[Callable[[int, int], None]] = None,
+    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
     stop_fn: Optional[Callable[[float], bool]] = None,
     save_fn: Optional[Callable[[BasePolicy], None]] = None,
     writer: Optional[SummaryWriter] = None,
@@ -53,11 +53,11 @@ def onpolicy_trainer(
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
     :param function train_fn: a function receives the current number of epoch
-        index and performs some operations at the beginning of training in this
-        epoch.
+        and step index, and performs some operations at the beginning of
+        training in this epoch.
     :param function test_fn: a function receives the current number of epoch
-        index and performs some operations at the beginning of testing in this
-        epoch.
+        and step index, and performs some operations at the beginning of
+        testing in this epoch.
     :param function save_fn: a function for saving policy when the undiscounted
         average mean reward in evaluation phase gets better.
     :param function stop_fn: a function receives the average undiscounted
@@ -81,12 +81,12 @@ def onpolicy_trainer(
     for epoch in range(1, 1 + max_epoch):
         # train
         policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(
             total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
         ) as t:
             while t.n < t.total:
+                if train_fn:
+                    train_fn(epoch, global_step)
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result["rew"]):
@@ -104,8 +104,6 @@ def onpolicy_trainer(
                             test_result["rew"])
                     else:
                         policy.train()
-                        if train_fn:
-                            train_fn(epoch)
                 losses = policy.update(
                     0, train_collector.buffer,
                     batch_size=batch_size, repeat=repeat_per_collect)
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 0c5d2dd..2c2fb54 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -10,7 +10,7 @@ from tianshou.policy import BasePolicy
 def test_episode(
     policy: BasePolicy,
     collector: Collector,
-    test_fn: Optional[Callable[[int], None]],
+    test_fn: Optional[Callable[[int, Optional[int]], None]],
     epoch: int,
     n_episode: Union[int, List[int]],
     writer: Optional[SummaryWriter] = None,
@@ -21,7 +21,7 @@ def test_episode(
     collector.reset_buffer()
     policy.eval()
     if test_fn:
-        test_fn(epoch)
+        test_fn(epoch, global_step)
     if collector.get_env_num() > 1 and isinstance(n_episode, int):
         n = collector.get_env_num()
         n_ = np.zeros(n) + n_episode // n
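
Reviewer note (not part of the patch): the sketch below shows how the changed callback signatures fit together after this diff -- train_fn(epoch, env_step), test_fn(epoch, env_step) and stop_fn(mean_rewards) -- using offpolicy_trainer as in the updated dqn tutorial. The CartPole task, network size, environment counts and epsilon schedule are illustrative assumptions, not values taken from this diff.

import gym
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

task = 'CartPole-v0'
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(8)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(10)])

state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(3, state_shape, action_shape)  # 3-layer MLP, illustrative size
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.9,
                   estimation_step=3, target_update_freq=320)

train_collector = Collector(policy, train_envs, ReplayBuffer(size=20000))
test_collector = Collector(policy, test_envs)
writer = SummaryWriter('log/dqn')


def train_fn(epoch, env_step):
    # epsilon now decays against the global env step, not the epoch index:
    # linear decay from 1.0 to 0.1 over the first 10k steps (illustrative)
    eps = max(0.1, 1.0 - env_step / 10000 * 0.9)
    policy.set_eps(eps)
    writer.add_scalar('train/eps', eps, global_step=env_step)


def test_fn(epoch, env_step):
    policy.set_eps(0.05)


def stop_fn(mean_rewards):
    # mean undiscounted return of the test episodes
    return mean_rewards >= env.spec.reward_threshold


result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer)
print(f'Finished training! Use {result["duration"]}')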