Performance improve (#18)
* improve performance: set one thread for NN, replace detach() op with torch.no_grad()
* fix PEP 8 errors
This commit is contained in:
parent b6c9db6b0b
commit 4d4d0daf9e
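For context, the core of the change is the standard PyTorch pattern sketched below (a standalone illustration, not code from this diff; `net` and `obs` are placeholder names): limit intra-op threading for the small networks used in these tests, and run forward passes that never need gradients under torch.no_grad() instead of building an autograd graph and discarding it with .detach().

import torch
import torch.nn as nn

torch.set_num_threads(1)   # small MLPs gain little from intra-op threading
net = nn.Linear(4, 2)      # placeholder for an actor/critic network
obs = torch.randn(128, 4)  # placeholder for a batch of observations

# before: record an autograd graph, then throw it away
q_old = net(obs).detach()

# after: skip recording the graph entirely
with torch.no_grad():
    q_new = net(obs)

assert not q_old.requires_grad and not q_new.requires_grad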
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 
 
 def test_ddpg(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -81,7 +84,8 @@ def test_ddpg(args=get_args()):
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--run-id', type=str, default='test')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=2)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=16)
@@ -47,6 +49,7 @@ def get_args():
 
 def _test_ppo(args=get_args()):
     # just a demo, I have not made it work :(
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -89,7 +92,8 @@ def _test_ppo(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'ppo')
+    log_path = os.path.join(args.logdir, args.task, 'ppo', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,14 +20,15 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -43,6 +45,7 @@ def get_args():
 
 
 def test_sac(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -86,7 +89,8 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'sac')
+    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -19,7 +20,8 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--run-id', type=str, default='test')
+    parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=3e-4)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
@@ -29,7 +31,7 @@ def get_args():
     parser.add_argument('--policy-noise', type=float, default=0.2)
     parser.add_argument('--noise-clip', type=float, default=0.5)
     parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=20)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
@@ -46,6 +48,7 @@ def get_args():
 
 
 def test_td3(args=get_args()):
+    torch.set_num_threads(1)  # we just need only one thread for NN
     env = gym.make(args.task)
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -250
@@ -90,7 +93,8 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3', args.run_id)
+    writer = SummaryWriter(log_path)
 
     def stop_fn(x):
         return x >= env.spec.reward_threshold
@@ -3,7 +3,7 @@ import torch
 import warnings
 import numpy as np
 from tianshou.env import BaseVectorEnv
-from tianshou.data import Batch, ReplayBuffer,\
+from tianshou.data import Batch, ReplayBuffer, \
     ListReplayBuffer
 from tianshou.utils import MovAvg
 
@@ -115,7 +115,8 @@ class Collector(object):
                 done=self._make_batch(self._done),
                 obs_next=None,
                 info=self._make_batch(self._info))
-            result = self.policy(batch_data, self.state)
+            with torch.no_grad():
+                result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
             if isinstance(result.act, torch.Tensor):
                 self._act = result.act.detach().cpu().numpy()
@@ -90,12 +90,15 @@ class DDPGPolicy(BasePolicy):
         return Batch(act=logits, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        target_q = self.critic_old(batch.obs_next, self(
-            batch, model='actor_old', input='obs_next', eps=0).act)
-        dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            target_q = self.critic_old(batch.obs_next, self(
+                batch, model='actor_old', input='obs_next', eps=0).act)
+            dev = target_q.device
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
@@ -62,17 +62,20 @@ class SACPolicy(DDPGPolicy):
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        obs_next_result = self(batch, input='obs_next')
-        a_ = obs_next_result.act
-        dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_),
-        ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            obs_next_result = self(batch, input='obs_next')
+            a_ = obs_next_result.act
+            dev = a_.device
+            batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_),
+            ) - self._alpha * obs_next_result.log_prob
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         obs_result = self(batch)
         a = obs_result.act
         current_q1, current_q1a = self.critic1(
@@ -51,19 +51,22 @@ class TD3Policy(DDPGPolicy):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        a_ = self(batch, model='actor_old', input='obs_next').act
-        dev = a_.device
-        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
-        if self._noise_clip >= 0:
-            noise = noise.clamp(-self._noise_clip, self._noise_clip)
-        a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
-        target_q = torch.min(
-            self.critic1_old(batch.obs_next, a_),
-            self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
+        with torch.no_grad():
+            a_ = self(batch, model='actor_old', input='obs_next').act
+            dev = a_.device
+            noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
+            if self._noise_clip >= 0:
+                noise = noise.clamp(-self._noise_clip, self._noise_clip)
+            a_ += noise
+            a_ = a_.clamp(self._range[0], self._range[1])
+            target_q = torch.min(
+                self.critic1_old(batch.obs_next, a_),
+                self.critic2_old(batch.obs_next, a_))
+            rew = torch.tensor(batch.rew,
+                               dtype=torch.float, device=dev)[:, None]
+            done = torch.tensor(batch.done,
+                                dtype=torch.float, device=dev)[:, None]
+            target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)