diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 7813e0c..1972c75 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 
 
 def test_ddpg(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 6fe769c..6bbbdc5 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=2)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():
 
 def _test_ppo(args=get_args()):
     # just a demo, I have not made it work :(
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
diff --git a/test/continuous/test_sac.py b/test/continuous/test_sac.py
index 6fe6f75..2d5b3df 100644
--- a/test/continuous/test_sac.py
+++ b/test/continuous/test_sac.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 
 
 def test_sac(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index da9637c..e7e82c9 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,7 +20,8 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
     parser.add_argument('--policy-noise', type=float, default=0.2)
     parser.add_argument('--noise-clip', type=float, default=0.5)
     parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():
 
 
 def test_td3(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 5d090b8..2c94e07 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -3,7 +3,7 @@ import torch
 import warnings
 import numpy as np
 from tianshou.env import BaseVectorEnv
-from tianshou.data import Batch, ReplayBuffer,\
+from tianshou.data import Batch, ReplayBuffer, \
     ListReplayBuffer
 from tianshou.utils import MovAvg
 
@@ -115,7 +115,8 @@ class Collector(object):
                 done=self._make_batch(self._done),
                 obs_next=None,
                 info=self._make_batch(self._info))
-            result = self.policy(batch_data, self.state)
+            with torch.no_grad():
+                result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
             if isinstance(result.act, torch.Tensor):
                 self._act = result.act.detach().cpu().numpy()
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py
index 8875550..5bdf19c 100644
--- a/tianshou/policy/ddpg.py
+++ b/tianshou/policy/ddpg.py
@@ -90,12 +90,15 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        target_q = self.critic_old(batch.obs_next, self(
-            batch, model='actor_old', input='obs_next', eps=0).act)
-        dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+            dev = target_q.device
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
diff --git a/tianshou/policy/sac.py b/tianshou/policy/sac.py
index 2c534dc..bc231db 100644
--- a/tianshou/policy/sac.py
+++ b/tianshou/policy/sac.py
@@ -62,17 +62,20 @@ class SACPolicy(DDPGPolicy):
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
-        dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
-        ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            obs_next_result = self(batch, input='obs_next')
+            a_ = obs_next_result.act
+            dev = a_.device
+            batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_),
+            ) - self._alpha * obs_next_result.log_prob
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         obs_result = self(batch)
         a = obs_result.act
         current_q1, current_q1a = self.critic1(
diff --git a/tianshou/policy/td3.py b/tianshou/policy/td3.py
index 03aff0a..d72cb5b 100644
--- a/tianshou/policy/td3.py
+++ b/tianshou/policy/td3.py
@@ -51,19 +51,22 @@ class TD3Policy(DDPGPolicy):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        a_ = self(batch, model='actor_old', input='obs_next').act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
-        if self._noise_clip >= 0:
-            noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            a_ = self(batch, model='actor_old', input='obs_next').act
+            dev = a_.device
+            noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+            if self._noise_clip >= 0:
+                noise = noise.clamp(-self._noise_clip, self._noise_clip)
+            a_ += noise
+            a_ = a_.clamp(self._range[0], self._range[1])
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_))
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)