diff --git a/.gitignore b/.gitignore
index 4ceb487..28b6ccc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,4 @@ dmypy.json
 # customize
 flake8.sh
 log/
+MUJOCO_LOG.TXT
diff --git a/test/continuous/test_sac.py b/test/continuous/test_sac.py
new file mode 100644
index 0000000..ae67d96
--- /dev/null
+++ b/test/continuous/test_sac.py
@@ -0,0 +1,114 @@
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+from tianshou.env import VectorEnv, SubprocVectorEnv
+
+if __name__ == '__main__':
+    from net import ActorProb, Critic
+else:  # pytest
+    from test.continuous.net import ActorProb, Critic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--critic-lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--tau', type=float, default=0.005)
+    parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_sac(args=get_args()):
+    env = gym.make(args.task)
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -250
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    # train_envs = gym.make(args.task)
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    actor = ActorProb(
+        args.layer_num, args.state_shape, args.action_shape,
+        args.max_action, args.device
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+    critic1 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+    critic2 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+    policy = SACPolicy(
+        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+        args.tau, args.gamma, args.alpha,
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=True)
+    # collector
+    train_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
+    test_collector = Collector(policy, test_envs)
+    train_collector.collect(n_step=args.buffer_size)
+    # log
+    writer = SummaryWriter(args.logdir)
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, stop_fn=stop_fn, writer=writer)
+    if args.task == 'Pendulum-v0':
+        assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_sac()
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py
index 05d7f97..f9ea4ce 100644
--- a/tianshou/policy/ddpg.py
+++ b/tianshou/policy/ddpg.py
@@ -15,9 +15,10 @@ class DDPGPolicy(BasePolicy):
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
                  action_range=None, reward_normalization=True):
         super().__init__()
-        self.actor, self.actor_old = actor, deepcopy(actor)
-        self.actor_old.eval()
-        self.actor_optim = actor_optim
+        if actor is not None:
+            self.actor, self.actor_old = actor, deepcopy(actor)
+            self.actor_old.eval()
+            self.actor_optim = actor_optim
         if critic is not None:
             self.critic, self.critic_old = critic, deepcopy(critic)
             self.critic_old.eval()
@@ -28,7 +29,11 @@ class DDPGPolicy(BasePolicy):
         self._gamma = gamma
         assert 0 <= exploration_noise, 'noise should not be negative'
         self._eps = exploration_noise
+        assert action_range is not None
         self._range = action_range
+        self._action_bias = (action_range[0] + action_range[1]) / 2
+        self._action_scale = (action_range[1] - action_range[0]) / 2
+        # it is only a little difference to use rand_normal
         # self.noise = OUNoise()
         self._rew_norm = reward_normalization
         self.__eps = np.finfo(np.float32).eps.item()
@@ -53,19 +58,27 @@
                 self.critic_old.parameters(), self.critic.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
 
+    def process_fn(self, batch, buffer, indice):
+        if self._rew_norm:
+            self._rew_mean = buffer.rew.mean()
+            self._rew_std = buffer.rew.std()
+        return batch
+
     def __call__(self, batch, state=None,
                  model='actor', input='obs', eps=None):
         model = getattr(self, model)
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
+        logits += self._action_bias
         if eps is None:
             eps = self._eps
         # noise = np.random.normal(0, eps, size=logits.shape)
         # noise = self.noise(logits.shape, eps)
         # logits += torch.tensor(noise, device=logits.device)
-        logits += torch.randn(size=logits.shape, device=logits.device) * eps
-        if self._range:
-            logits = logits.clamp(self._range[0], self._range[1])
+        if eps > 0:
+            logits += torch.randn(
+                size=logits.shape, device=logits.device) * eps
+        logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
@@ -74,7 +87,7 @@
         dev = target_q.device
         rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
         if self._rew_norm:
-            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
+            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
         done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
         target_q = rew + ((1. - done) * self._gamma * target_q).detach()
         current_q = self.critic(batch.obs, batch.act)
diff --git a/tianshou/policy/sac.py b/tianshou/policy/sac.py
index 0394958..6bb815b 100644
--- a/tianshou/policy/sac.py
+++ b/tianshou/policy/sac.py
@@ -9,18 +9,98 @@ from tianshou.policy import DDPGPolicy
 
 class SACPolicy(DDPGPolicy):
     """docstring for SACPolicy"""
 
-    def __init__(self, actor, actor_optim, critic, critic_optim,
-                 tau, gamma, ):
-        super().__init__()
-        self.actor, self.actor_old = actor, deepcopy(actor)
-        self.actor_old.eval()
-        self.actor_optim = actor_optim
-        self.critic, self.critic_old = critic, deepcopy(critic)
-        self.critic_old.eval()
-        self.critic_optim = critic_optim
-    def __call__(self, batch, state=None):
-        pass
+    def __init__(self, actor, actor_optim, critic1, critic1_optim,
+                 critic2, critic2_optim, tau=0.005, gamma=0.99,
+                 alpha=0.2, action_range=None, reward_normalization=True):
+        super().__init__(None, None, None, None, tau, gamma, 0,
+                         action_range, reward_normalization)
+        self.actor, self.actor_optim = actor, actor_optim
+        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
+        self.critic1_old.eval()
+        self.critic1_optim = critic1_optim
+        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
+        self.critic2_old.eval()
+        self.critic2_optim = critic2_optim
+        self._alpha = alpha
+        self.__eps = np.finfo(np.float32).eps.item()
+
+    def train(self):
+        self.training = True
+        self.actor.train()
+        self.critic1.train()
+        self.critic2.train()
+
+    def eval(self):
+        self.training = False
+        self.actor.eval()
+        self.critic1.eval()
+        self.critic2.eval()
+
+    def sync_weight(self):
+        for o, n in zip(
+                self.critic1_old.parameters(), self.critic1.parameters()):
+            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
+        for o, n in zip(
+                self.critic2_old.parameters(), self.critic2.parameters()):
+            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
+
+    def __call__(self, batch, state=None, input='obs'):
+        obs = getattr(batch, input)
+        logits, h = self.actor(obs, state=state, info=batch.info)
+        assert isinstance(logits, tuple)
+        dist = torch.distributions.Normal(*logits)
+
+        x = dist.rsample()
+        y = torch.tanh(x)
+        act = y * self._action_scale + self._action_bias
+        log_prob = dist.log_prob(x) - torch.log(
+            self._action_scale * (1 - y.pow(2)) + self.__eps)
+        act = act.clamp(self._range[0], self._range[1])
+        return Batch(
+            logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
     def learn(self, batch, batch_size=None, repeat=1):
-        pass
+        obs_next_result = self(batch, input='obs_next')
+        a_ = obs_next_result.act
+        dev = a_.device
+        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        target_q = torch.min(
+            self.critic1_old(batch.obs_next, a_),
+            self.critic2_old(batch.obs_next, a_),
+        ) - self._alpha * obs_next_result.log_prob
+        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
+        if self._rew_norm:
+            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
+        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
+        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
+        obs_result = self(batch)
+        a = obs_result.act
+        current_q1, current_q1a = self.critic1(
+            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
+        ).split(batch.obs.shape[0])
+        current_q2, current_q2a = self.critic2(
+            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
+        ).split(batch.obs.shape[0])
+        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
+            current_q1a, current_q2a)).mean()
+        # critic 1
+        critic1_loss = F.mse_loss(current_q1, target_q)
+        self.critic1_optim.zero_grad()
+        critic1_loss.backward(retain_graph=True)
+        self.critic1_optim.step()
+        # critic 2
+        critic2_loss = F.mse_loss(current_q2, target_q)
+        self.critic2_optim.zero_grad()
+        critic2_loss.backward(retain_graph=True)
+        self.critic2_optim.step()
+        # actor
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
+        self.sync_weight()
+        return {
+            'loss/actor': actor_loss.detach().cpu().numpy(),
+            'loss/critic1': critic1_loss.detach().cpu().numpy(),
+            'loss/critic2': critic2_loss.detach().cpu().numpy(),
+        }
diff --git a/tianshou/policy/td3.py b/tianshou/policy/td3.py
index 0dc1933..09f40c9 100644
--- a/tianshou/policy/td3.py
+++ b/tianshou/policy/td3.py
@@ -8,6 +8,7 @@ from tianshou.policy import DDPGPolicy
 
 class TD3Policy(DDPGPolicy):
     """docstring for TD3Policy"""
+
     def __init__(self, actor, actor_optim, critic1, critic1_optim,
                  critic2, critic2_optim, tau=0.005, gamma=0.99,
                  exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
@@ -57,14 +58,13 @@
             if self._noise_clip >= 0:
                 noise = noise.clamp(-self._noise_clip, self._noise_clip)
             a_ += noise
-        if self._range:
-            a_ = a_.clamp(self._range[0], self._range[1])
+        a_ = a_.clamp(self._range[0], self._range[1])
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))
         rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
         if self._rew_norm:
-            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
+            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
         done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
         target_q = rew + ((1. - done) * self._gamma * target_q).detach()
         # critic 1
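Note on the new SACPolicy.__call__ above: the actor returns the mean and standard deviation of a Gaussian, a reparameterized sample is squashed through tanh, rescaled into the action range, and the log-probability is corrected for that change of variables. A minimal standalone sketch of the same computation (the action bounds and batch shape below are illustrative stand-ins, not values taken from this patch):

import numpy as np
import torch

eps = np.finfo(np.float32).eps.item()
action_low, action_high = -2.0, 2.0              # illustrative bounds (e.g. Pendulum-v0 torque)
action_scale = (action_high - action_low) / 2
action_bias = (action_high + action_low) / 2

mu, sigma = torch.zeros(4, 1), torch.ones(4, 1)  # stand-ins for the actor's (mu, sigma) output
dist = torch.distributions.Normal(mu, sigma)

x = dist.rsample()                    # reparameterized sample, keeps gradients w.r.t. mu, sigma
y = torch.tanh(x)                     # squash to (-1, 1)
act = (y * action_scale + action_bias).clamp(action_low, action_high)
# change-of-variables correction for the tanh + rescale transform
log_prob = dist.log_prob(x) - torch.log(action_scale * (1 - y.pow(2)) + eps)

Using rsample() rather than sample() is what lets the actor loss alpha * log_prob - min(Q1, Q2) backpropagate through the sampled action into the actor parameters.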