Improve performance (#18)

* Improve performance

  set a single thread for the neural network (torch.set_num_threads(1))
  replace detach() ops with torch.no_grad() blocks when computing targets and selecting actions

* Fix PEP 8 errors
Oblivion 2020-04-05 09:10:21 +08:00 committed by GitHub
parent b6c9db6b0b
commit 4d4d0daf9e
8 changed files with 70 additions and 44 deletions
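
Both optimizations follow the same pattern in every file below. Here is a minimal, self-contained sketch of that pattern; the network, tensor shapes and gamma value are placeholders standing in for Tianshou's critics and replay-buffer batches, not code taken from this commit:

import torch
import torch.nn as nn

# Placeholder setup: `net`, the shapes and `gamma` are illustrative only.
net = nn.Linear(4, 1)
obs_next = torch.randn(64, 4)
rew = torch.randn(64, 1)
done = torch.zeros(64, 1)
gamma = 0.99

torch.set_num_threads(1)  # small MLPs gain little from extra intra-op threads

# Before: the forward pass records an autograd graph that detach() then throws away.
target_q = (rew + (1. - done) * gamma * net(obs_next)).detach()

# After: wrapping the computation in torch.no_grad() skips building the graph
# entirely, which saves time and memory; the resulting target is identical.
with torch.no_grad():
    target_q = rew + (1. - done) * gamma * net(obs_next)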

File 1 of 8: DDPG test script (test_ddpg)

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 def test_ddpg(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
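
The logging change gives every run its own TensorBoard directory instead of mixing all runs into one algorithm folder. A short usage sketch under the defaults shown above; the 'log' value for --logdir is an assumption, since that default is not visible in this diff:

import os
from torch.utils.tensorboard import SummaryWriter

logdir, task, run_id = 'log', 'Pendulum-v0', 'test'  # 'log' is assumed, not shown in the diff
# old path: args.logdir + '/' + 'ddpg'  -> all runs share log/ddpg
# new path: log/Pendulum-v0/ddpg/test   -> one directory per task, algorithm and run id
log_path = os.path.join(logdir, task, 'ddpg', run_id)
writer = SummaryWriter(log_path)
writer.add_scalar('train/reward', 0.0, global_step=0)
writer.close()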

File 2 of 8: PPO test script (_test_ppo)

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=2)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():
 def _test_ppo(args=get_args()):
     # just a demo, I have not made it work :(
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

File 3 of 8: SAC test script (test_sac)

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 def test_sac(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

File 4 of 8: TD3 test script (test_td3)

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,7 +20,8 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
     parser.add_argument('--policy-noise', type=float, default=0.2)
     parser.add_argument('--noise-clip', type=float, default=0.5)
     parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():
 def test_td3(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold

File 5 of 8: data collector (class Collector)

@@ -3,7 +3,7 @@ import torch
 import warnings
 import numpy as np
 from tianshou.env import BaseVectorEnv
-from tianshou.data import Batch, ReplayBuffer,\
+from tianshou.data import Batch, ReplayBuffer, \
     ListReplayBuffer
 from tianshou.utils import MovAvg
@@ -115,7 +115,8 @@ class Collector(object):
             done=self._make_batch(self._done),
             obs_next=None,
             info=self._make_batch(self._info))
-        result = self.policy(batch_data, self.state)
+        with torch.no_grad():
+            result = self.policy(batch_data, self.state)
         self.state = result.state if hasattr(result, 'state') else None
         if isinstance(result.act, torch.Tensor):
             self._act = result.act.detach().cpu().numpy()

File 6 of 8: DDPG policy (class DDPGPolicy)

@@ -90,12 +90,15 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        target_q = self.critic_old(batch.obs_next, self(
-            batch, model='actor_old', input='obs_next', eps=0).act)
-        dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+            dev = target_q.device
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()

File 7 of 8: SAC policy (class SACPolicy)

@@ -62,17 +62,20 @@ class SACPolicy(DDPGPolicy):
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
-        dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
-        ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            obs_next_result = self(batch, input='obs_next')
+            a_ = obs_next_result.act
+            dev = a_.device
+            batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_),
+            ) - self._alpha * obs_next_result.log_prob
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         obs_result = self(batch)
         a = obs_result.act
         current_q1, current_q1a = self.critic1(

File 8 of 8: TD3 policy (class TD3Policy)

@@ -51,19 +51,22 @@ class TD3Policy(DDPGPolicy):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        a_ = self(batch, model='actor_old', input='obs_next').act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
-        if self._noise_clip >= 0:
-            noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            a_ = self(batch, model='actor_old', input='obs_next').act
+            dev = a_.device
+            noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+            if self._noise_clip >= 0:
+                noise = noise.clamp(-self._noise_clip, self._noise_clip)
+            a_ += noise
+            a_ = a_.clamp(self._range[0], self._range[1])
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_))
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)