commit 30a0fc079c (parent a87563b8e6)
Author: Trinkle23897
Date:   2020-03-23 11:34:52 +08:00

11 changed files with 254 additions and 13 deletions


@@ -79,7 +79,7 @@ def test_ddpg(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    test_collector = Collector(policy, test_envs)
     # log
     writer = SummaryWriter(args.logdir)


@@ -86,7 +86,7 @@ def _test_ppo(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
     writer = SummaryWriter(args.logdir)

test/continuous/test_td3.py (new file, 118 lines)

@@ -0,0 +1,118 @@
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

if __name__ == '__main__':
    from net import Actor, Critic
else:  # pytest
    from test.continuous.net import Actor, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--exploration-noise', type=float, default=0.1)
    parser.add_argument('--policy-noise', type=float, default=0.2)
    parser.add_argument('--noise-clip', type=float, default=0.5)
    parser.add_argument('--update-actor-freq', type=int, default=2)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_td3(args=get_args()):
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = Actor(
        args.layer_num, args.state_shape, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    policy = TD3Policy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
        args.update_actor_freq, args.noise_clip,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.buffer_size)
    # log
    writer = SummaryWriter(args.logdir)

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    if args.task == 'Pendulum-v0':
        assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
    test_td3()
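For a quick smoke test, the script above can also be driven from Python rather than the command line. A minimal sketch, assuming the repository root is on the import path so the module resolves as test.continuous.test_td3 (the shortened settings below are illustrative, not part of the commit):

from test.continuous.test_td3 import get_args, test_td3

args = get_args()          # parse_known_args, so stray CLI flags are ignored
args.epoch = 5             # shorten training for a quick check
args.test_num = 8          # fewer evaluation environments
test_td3(args)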


@@ -73,7 +73,7 @@ def test_a2c(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    test_collector = Collector(policy, test_envs)
     # log
     writer = SummaryWriter(args.logdir)


@@ -66,8 +66,8 @@ def test_dqn(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
-    train_collector.collect(n_step=args.batch_size)
+    test_collector = Collector(policy, test_envs)
+    train_collector.collect(n_step=args.buffer_size)
     # log
     writer = SummaryWriter(args.logdir)


@@ -121,7 +121,7 @@ def test_pg(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    test_collector = Collector(policy, test_envs)
     # log
     writer = SummaryWriter(args.logdir)


@@ -78,7 +78,7 @@ def test_ppo(args=get_args()):
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    test_collector = Collector(policy, test_envs)
     # log
     writer = SummaryWriter(args.logdir)


@@ -4,6 +4,8 @@ from tianshou.policy.pg import PGPolicy
 from tianshou.policy.a2c import A2CPolicy
 from tianshou.policy.ddpg import DDPGPolicy
 from tianshou.policy.ppo import PPOPolicy
+from tianshou.policy.td3 import TD3Policy
+from tianshou.policy.sac import SACPolicy

 __all__ = [
     'BasePolicy',
@@ -12,4 +14,6 @@ __all__ = [
     'A2CPolicy',
     'DDPGPolicy',
     'PPOPolicy',
+    'TD3Policy',
+    'SACPolicy',
 ]


@@ -18,9 +18,10 @@ class DDPGPolicy(BasePolicy):
         self.actor, self.actor_old = actor, deepcopy(actor)
         self.actor_old.eval()
         self.actor_optim = actor_optim
-        self.critic, self.critic_old = critic, deepcopy(critic)
-        self.critic_old.eval()
-        self.critic_optim = critic_optim
+        if critic is not None:
+            self.critic, self.critic_old = critic, deepcopy(critic)
+            self.critic_old.eval()
+            self.critic_optim = critic_optim
         assert 0 < tau <= 1, 'tau should in (0, 1]'
         self._tau = tau
         assert 0 < gamma <= 1, 'gamma should in (0, 1]'
@@ -45,9 +46,6 @@ class DDPGPolicy(BasePolicy):
         self.actor.eval()
         self.critic.eval()

-    def process_fn(self, batch, buffer, indice):
-        return batch
-
     def sync_weight(self):
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

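The "if critic is not None" guard added above lets a subclass pass no critic to DDPGPolicy and register its own critic networks instead, which is how the new TD3Policy below is built. A minimal sketch of that pattern under the constructor signature shown in this diff (the class name here is hypothetical):

from copy import deepcopy
from tianshou.policy import DDPGPolicy

class TwinCriticPolicy(DDPGPolicy):
    """Hypothetical subclass that skips the base class's single critic."""
    def __init__(self, actor, actor_optim, critic1, critic1_optim,
                 critic2, critic2_optim, tau=0.005, gamma=0.99,
                 exploration_noise=0.1, action_range=None,
                 reward_normalization=True):
        # critic=None: the guarded branch in DDPGPolicy.__init__ is skipped
        super().__init__(actor, actor_optim, None, None,
                         tau, gamma, exploration_noise, action_range,
                         reward_normalization)
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim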
tianshou/policy/sac.py (new file, 26 lines)

@@ -0,0 +1,26 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import DDPGPolicy


class SACPolicy(DDPGPolicy):
    """docstring for SACPolicy"""

    def __init__(self, actor, actor_optim, critic, critic_optim,
                 tau, gamma):
        super().__init__()
        self.actor, self.actor_old = actor, deepcopy(actor)
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critic, self.critic_old = critic, deepcopy(critic)
        self.critic_old.eval()
        self.critic_optim = critic_optim

    def __call__(self, batch, state=None):
        pass

    def learn(self, batch, batch_size=None, repeat=1):
        pass

tianshou/policy/td3.py (new file, 95 lines)

@@ -0,0 +1,95 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F

from tianshou.policy import DDPGPolicy


class TD3Policy(DDPGPolicy):
    """docstring for TD3Policy"""

    def __init__(self, actor, actor_optim, critic1, critic1_optim,
                 critic2, critic2_optim, tau=0.005, gamma=0.99,
                 exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
                 noise_clip=0.5, action_range=None, reward_normalization=True):
        super().__init__(actor, actor_optim, None, None,
                         tau, gamma, exploration_noise, action_range,
                         reward_normalization)
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim
        self._policy_noise = policy_noise
        self._freq = update_actor_freq
        self._noise_clip = noise_clip
        self._cnt = 0
        self._last = 0
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self):
        self.training = True
        self.actor.train()
        self.critic1.train()
        self.critic2.train()

    def eval(self):
        self.training = False
        self.actor.eval()
        self.critic1.eval()
        self.critic2.eval()

    def sync_weight(self):
        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
        for o, n in zip(
                self.critic1_old.parameters(), self.critic1.parameters()):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
        for o, n in zip(
                self.critic2_old.parameters(), self.critic2.parameters()):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

    def learn(self, batch, batch_size=None, repeat=1):
        a_ = self(batch, model='actor_old', input='obs_next').act
        dev = a_.device
        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
        if self._noise_clip >= 0:
            noise = noise.clamp(-self._noise_clip, self._noise_clip)
        a_ += noise
        if self._range:
            a_ = a_.clamp(self._range[0], self._range[1])
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_))
        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
        if self._rew_norm:
            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act)
        critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()
        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act)
        critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        if self._cnt % self._freq == 0:
            actor_loss = -self.critic1(
                batch.obs, self(batch, eps=0).act).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.detach().cpu().numpy()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            'loss/actor': self._last,
            'loss/critic1': critic1_loss.detach().cpu().numpy(),
            'loss/critic2': critic2_loss.detach().cpu().numpy(),
        }
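For reference, the critic update in learn() above uses the standard TD3 clipped double-Q target, r + gamma * (1 - d) * min(Q1'(s', a~), Q2'(s', a~)), where a~ is the target action perturbed by clipped Gaussian noise. A standalone sketch of just that computation (function and argument names are illustrative, not part of this commit):

import torch

def clipped_double_q_target(actor_old, critic1_old, critic2_old,
                            obs_next, rew, done, gamma=0.99,
                            policy_noise=0.2, noise_clip=0.5,
                            action_range=None):
    # target action from the target actor, smoothed with clipped noise
    a_ = actor_old(obs_next)
    noise = (torch.randn_like(a_) * policy_noise).clamp(-noise_clip, noise_clip)
    a_ = a_ + noise
    if action_range is not None:
        a_ = a_.clamp(action_range[0], action_range[1])
    # take the minimum of the two target critics to curb overestimation
    target_q = torch.min(critic1_old(obs_next, a_), critic2_old(obs_next, a_))
    # one-step bootstrapped target, detached from the graph
    return (rew + (1. - done) * gamma * target_q).detach()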