diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py new file mode 100644 index 0000000..1c90708 --- /dev/null +++ b/examples/halfcheetahBullet_v0_sac.py @@ -0,0 +1,120 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +try: + import pybullet_envs +except ImportError: + pass + +from continuous_net import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='HalfCheetahBulletEnv-v0') + parser.add_argument('--run-id', type=str, default='test') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--actor-lr', type=float, default=3e-4) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--layer-num', type=int, default=1) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=4) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--log-interval', type=int, default=100) + parser.add_argument('--render', type=float, default=0.) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_sac(args=get_args()): + torch.set_num_threads(1) + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + actor = ActorProb( + args.layer_num, args.state_shape, args.action_shape, + args.max_action, args.device + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + critic1 = Critic( + args.layer_num, args.state_shape, args.action_shape, args.device + ).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic( + args.layer_num, args.state_shape, args.action_shape, args.device + ).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + policy = SACPolicy( + actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, + args.tau, args.gamma, args.alpha, + [env.action_space.low[0], env.action_space.high[0]], + reward_normalization=True, ignore_done=True) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.buffer_size) + # log + log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id) + writer = SummaryWriter(log_path) + + def stop_fn(x): + return x >= env.spec.reward_threshold + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, stop_fn=stop_fn, + writer=writer, log_interval=args.log_interval) + assert stop_fn(result['best_reward']) + train_collector.close() + test_collector.close() + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() + + +if __name__ == '__main__': + __all__ = ('pybullet_envs',) # Avoid F401 error :) + test_sac() diff --git a/setup.py b/setup.py index 29bdb4a..5e5e6fa 100644 --- a/setup.py +++ b/setup.py @@ -69,5 +69,8 @@ setup( 'mujoco': [ 'mujoco_py', ], + 'pybullet': [ + 'pybullet', + ], }, ) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 470f269..7efec91 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, step_per_epoch, collect_per_step, episode_per_test, batch_size, train_fn=None, test_fn=None, stop_fn=None, - writer=None, verbose=True, task=''): + writer=None, log_interval=1, verbose=True, task=''): global_step = 0 best_epoch, best_reward = -1, -1 stat = {} @@ -45,7 +45,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, losses = policy.learn(train_collector.sample(batch_size)) for k in result.keys(): data[k] = f'{result[k]:.2f}' - if writer: + if writer and global_step % log_interval == 0: writer.add_scalar( k + '_' + task if task else k, result[k], global_step=global_step) @@ -54,7 +54,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f'{stat[k].get():.6f}' - if writer: + if writer and global_step % log_interval == 0: writer.add_scalar( k + '_' + task if task else k, stat[k].get(), global_step=global_step) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index a88d0fd..29f0a43 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -9,7 +9,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, step_per_epoch, collect_per_step, repeat_per_collect, episode_per_test, batch_size, train_fn=None, test_fn=None, stop_fn=None, - writer=None, verbose=True, task=''): + writer=None, log_interval=1, verbose=True, task=''): global_step = 0 best_epoch, best_reward = -1, -1 stat = {} @@ -50,7 +50,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, global_step += step for k in result.keys(): data[k] = f'{result[k]:.2f}' - if writer: + if writer and global_step % log_interval == 0: writer.add_scalar( k + '_' + task if task else k, result[k], global_step=global_step) @@ -59,7 +59,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f'{stat[k].get():.6f}' - if writer and global_step: + if writer and global_step % log_interval == 0: writer.add_scalar( k + '_' + task if task else k, stat[k].get(), global_step=global_step)