Improve performance (#18)

* improve performance

set a single thread for the NN
replace the detach() op with torch.no_grad() (both changes are sketched below)

* fix PEP 8 errors
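
A minimal sketch of the two performance changes above, using a stand-in critic and dummy tensors (critic_old, gamma, and the shapes here are illustrative, not the actual Tianshou objects):

import torch
import torch.nn as nn

torch.set_num_threads(1)  # a single intra-op thread is enough for small networks

critic_old = nn.Linear(4, 1)   # stand-in for a frozen target critic
obs_next = torch.randn(8, 4)   # dummy transition batch
rew = torch.randn(8, 1)
done = torch.zeros(8, 1)
gamma = 0.99

# old pattern: build the autograd graph for the target, then cut it with detach()
target_q_old = (rew + (1. - done) * gamma * critic_old(obs_next)).detach()

# new pattern: never record the graph for the target computation at all
with torch.no_grad():
    target_q_new = rew + (1. - done) * gamma * critic_old(obs_next)

assert torch.allclose(target_q_old, target_q_new)  # same values, less bookkeeping

Both paths produce the same target values; the no_grad() version simply skips recording operations for autograd, which is what the policy-file hunks below do.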
Oblivion 2020-04-05 09:10:21 +08:00 committed by GitHub
parent b6c9db6b0b
commit 4d4d0daf9e
8 changed files with 70 additions and 44 deletions

View File

@@ -1,3 +1,4 @@
+import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@ else:  # pytest
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=4)
    parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
def test_ddpg(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    writer = SummaryWriter(log_path)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

View File

@@ -1,3 +1,4 @@
+import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@ else:  # pytest
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--run-id', type=str, default='test')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=2)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():
def _test_ppo(args=get_args()):
    # just a demo, I have not made it work :(
+    torch.set_num_threads(1)  # we just need only one thread for NN
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.step_per_epoch)
    # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    writer = SummaryWriter(log_path)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

View File

@@ -1,3 +1,4 @@
+import os
import gym
import torch
import pprint
@@ -19,14 +20,15 @@ else:  # pytest
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--alpha', type=float, default=0.2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
def test_sac(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
-    writer = SummaryWriter(args.logdir + '/' + 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    writer = SummaryWriter(log_path)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

View File

@@ -1,3 +1,4 @@
+import os
import gym
import torch
import pprint
@@ -19,7 +20,8 @@ else:  # pytest
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
    parser.add_argument('--policy-noise', type=float, default=0.2)
    parser.add_argument('--noise-clip', type=float, default=0.5)
    parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():
def test_td3(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    writer = SummaryWriter(log_path)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

View File

@@ -115,6 +115,7 @@ class Collector(object):
                done=self._make_batch(self._done),
                obs_next=None,
                info=self._make_batch(self._info))
+            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):

View File

@@ -90,12 +90,15 @@ class DDPGPolicy(BasePolicy):
        return Batch(act=logits, state=h)

    def learn(self, batch, batch_size=None, repeat=1):
+        with torch.no_grad():
            target_q = self.critic_old(batch.obs_next, self(
                batch, model='actor_old', input='obs_next', eps=0).act)
            dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
        current_q = self.critic(batch.obs, batch.act)
        critic_loss = F.mse_loss(current_q, target_q)
        self.critic_optim.zero_grad()

View File

@@ -62,6 +62,7 @@ class SACPolicy(DDPGPolicy):
            logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

    def learn(self, batch, batch_size=None, repeat=1):
+        with torch.no_grad():
            obs_next_result = self(batch, input='obs_next')
            a_ = obs_next_result.act
            dev = a_.device
@@ -70,9 +71,11 @@ class SACPolicy(DDPGPolicy):
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_),
            ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
        obs_result = self(batch)
        a = obs_result.act
        current_q1, current_q1a = self.critic1(

View File

@@ -51,6 +51,7 @@ class TD3Policy(DDPGPolicy):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

    def learn(self, batch, batch_size=None, repeat=1):
+        with torch.no_grad():
            a_ = self(batch, model='actor_old', input='obs_next').act
            dev = a_.device
            noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
@@ -61,9 +62,11 @@ class TD3Policy(DDPGPolicy):
            target_q = torch.min(
                self.critic1_old(batch.obs_next, a_),
                self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act)
        critic1_loss = F.mse_loss(current_q1, target_q)