Improve performance (#18)

* improve performance

set a single thread for the NN
replace the detach() op with torch.no_grad()

* fix PEP 8 errors
Oblivion 2020-04-05 09:10:21 +08:00 committed by GitHub
parent b6c9db6b0b
commit 4d4d0daf9e
8 changed files with 70 additions and 44 deletions
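
Both changes are generic PyTorch patterns rather than anything tianshou-specific. A .detach() call still runs the forward pass with autograd recording enabled and only cuts the result loose from the graph afterwards; torch.no_grad() disables recording up front, so the target-network forward pass never builds a graph at all, saving both time and memory. A minimal sketch of the difference (toy network, for illustration only):

    import torch

    net = torch.nn.Linear(8, 1)   # stand-in for a target critic
    obs = torch.randn(32, 8)

    # before: the autograd graph is built during the forward pass, then discarded
    target = net(obs).detach()

    # after: recording is off, so no graph is built in the first place
    with torch.no_grad():
        target = net(obs)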

test/continuous/test_ddpg.py

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 def test_ddpg(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
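
The torch.set_num_threads(1) call added above targets a different cost: for the small networks in these tests, PyTorch's default intra-op thread pool can spend more time on synchronization than on the actual matrix math. A rough way to check the effect on your own machine (toy benchmark, CPU only; rerun with a different thread count to compare, numbers will vary):

    import time
    import torch

    net = torch.nn.Sequential(
        torch.nn.Linear(3, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1))
    obs = torch.randn(64, 3)

    torch.set_num_threads(1)  # try 1 vs. the default in separate runs
    start = time.time()
    with torch.no_grad():
        for _ in range(10000):
            net(obs)
    print(f'{torch.get_num_threads()} thread(s): {time.time() - start:.3f}s')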

test/continuous/test_ppo.py

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=2)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():
 def _test_ppo(args=get_args()):
     # just a demo, I have not made it work :(
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

test/continuous/test_sac.py

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 def test_sac(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

test/continuous/test_td3.py

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,7 +20,8 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
     parser.add_argument('--policy-noise', type=float, default=0.2)
     parser.add_argument('--noise-clip', type=float, default=0.5)
     parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():
 def test_td3(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

tianshou/data/collector.py

@@ -3,7 +3,7 @@
 import torch
 import warnings
 import numpy as np
 from tianshou.env import BaseVectorEnv
-from tianshou.data import Batch, ReplayBuffer,\
+from tianshou.data import Batch, ReplayBuffer, \
     ListReplayBuffer
 from tianshou.utils import MovAvg
@@ -115,7 +115,8 @@ class Collector(object):
             done=self._make_batch(self._done),
             obs_next=None,
             info=self._make_batch(self._info))
-        result = self.policy(batch_data, self.state)
+        with torch.no_grad():
+            result = self.policy(batch_data, self.state)
         self.state = result.state if hasattr(result, 'state') else None
         if isinstance(result.act, torch.Tensor):
             self._act = result.act.detach().cpu().numpy()

tianshou/policy/ddpg.py

@@ -90,12 +90,15 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        target_q = self.critic_old(batch.obs_next, self(
-            batch, model='actor_old', input='obs_next', eps=0).act)
-        dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+            dev = target_q.device
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
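
Note that the trailing .detach() on target_q is dropped, not merely moved: a tensor produced inside torch.no_grad() already has requires_grad=False, so detaching it again would be a no-op. A quick check (toy tensor, illustration only):

    import torch

    w = torch.ones(3, requires_grad=True)
    with torch.no_grad():
        target = w * 2.0
    print(target.requires_grad)  # False, so no extra detach() is needed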

tianshou/policy/sac.py

@@ -62,17 +62,20 @@ class SACPolicy(DDPGPolicy):
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
-        dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
-        ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            obs_next_result = self(batch, input='obs_next')
+            a_ = obs_next_result.act
+            dev = a_.device
+            batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_),
+            ) - self._alpha * obs_next_result.log_prob
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         obs_result = self(batch)
         a = obs_result.act
         current_q1, current_q1a = self.critic1(

tianshou/policy/td3.py

@@ -51,19 +51,22 @@ class TD3Policy(DDPGPolicy):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        a_ = self(batch, model='actor_old', input='obs_next').act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
-        if self._noise_clip >= 0:
-            noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            a_ = self(batch, model='actor_old', input='obs_next').act
+            dev = a_.device
+            noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+            if self._noise_clip >= 0:
+                noise = noise.clamp(-self._noise_clip, self._noise_clip)
+            a_ += noise
+            a_ = a_.clamp(self._range[0], self._range[1])
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_))
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)