diff --git a/.gitignore b/.gitignore index 0ecb650..604f2df 100644 --- a/.gitignore +++ b/.gitignore @@ -137,7 +137,6 @@ dmypy.json .pytype/ # customize -flake8.sh log/ MUJOCO_LOG.TXT *.pth diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index d73e1f5..9ca6845 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--noise_std', type=float, default=1.2) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) - parser.add_argument('--auto_alpha', type=bool, default=True) + parser.add_argument('--auto_alpha', type=int, default=1) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py new file mode 100644 index 0000000..865924a --- /dev/null +++ b/test/discrete/test_sac.py @@ -0,0 +1,117 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer +from tianshou.policy import DiscreteSACPolicy +from tianshou.utils.net.discrete import Actor, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--actor-lr', type=float, default=3e-4) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument('--gamma', type=float, default=0.95) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.05) + parser.add_argument('--auto_alpha', type=int, default=0) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=1) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.0) + parser.add_argument('--rew-norm', type=int, default=0) + parser.add_argument('--ignore-done', type=int, default=0) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_discrete_sac(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = Actor(net, args.action_shape, softmax_output=False).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net = Net(args.layer_num, args.state_shape, device=args.device) + critic1 = Critic(net, last_size=args.action_shape).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net = Net(args.layer_num, args.state_shape, device=args.device) + critic2 = Critic(net, last_size=args.action_shape).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # better not to use auto alpha in CartPole + if args.auto_alpha: + target_entropy = 0.98 * np.log(np.prod(args.action_shape)) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = DiscreteSACPolicy( + actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, + args.tau, args.gamma, args.alpha, + reward_normalization=args.rew_norm, + ignore_done=args.ignore_done) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.buffer_size) + # log + log_path = os.path.join(args.logdir, args.task, 'discrete_sac') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(x): + return x >= env.spec.reward_threshold + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + test_in_train=False) + assert stop_fn(result['best_reward']) + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == '__main__': + test_discrete_sac() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 7383390..cf1f4de 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -8,6 +8,7 @@ from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -22,5 +23,6 @@ __all__ = [ "PPOPolicy", "TD3Policy", "SACPolicy", + "DiscreteSACPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index f19f56e..81bf7f6 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -140,7 +140,7 @@ class DDPGPolicy(BasePolicy): more detailed explanation. """ model = getattr(self, model) - obs = getattr(batch, input) + obs = batch[input] actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias if self._noise and self.training and explorating: diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py new file mode 100644 index 0000000..d9fa35e --- /dev/null +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -0,0 +1,148 @@ +import torch +import numpy as np +from torch.distributions import Categorical +from typing import Any, Dict, Tuple, Union, Optional + +from tianshou.policy import SACPolicy +from tianshou.data import Batch, ReplayBuffer, to_torch + + +class DiscreteSACPolicy(SACPolicy): + """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s -> Q(s)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s -> Q(s)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. + :param float tau: param for soft update of the target network, defaults to + 0.005. + :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy + regularization coefficient, default to 0.2. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + alpha is automatatically tuned. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param bool ignore_done: ignore the done flag while training the policy, + defaults to ``False``. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + tau: float = 0.005, + gamma: float = 0.99, + alpha: Union[ + float, Tuple[float, torch.Tensor, torch.optim.Optimizer] + ] = 0.2, + reward_normalization: bool = False, + ignore_done: bool = False, + estimation_step: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(actor, actor_optim, critic1, critic1_optim, critic2, + critic2_optim, (-np.inf, np.inf), tau, gamma, alpha, + reward_normalization, ignore_done, estimation_step, + **kwargs) + self._alpha: Union[float, torch.Tensor] + + def forward( # type: ignore + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + **kwargs: Any, + ) -> Batch: + obs = batch[input] + logits, h = self.actor(obs, state=state, info=batch.info) + dist = Categorical(logits=logits) + act = dist.sample() + return Batch(logits=logits, act=act, state=h, dist=dist) + + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: + batch = buffer[indice] # batch.obs: s_{t+n} + with torch.no_grad(): + obs_next_result = self(batch, input="obs_next") + dist = obs_next_result.dist + target_q = dist.probs * torch.min( + self.critic1_old(batch.obs_next), + self.critic2_old(batch.obs_next), + ) + target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy() + return target_q + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + weight = batch.pop("weight", 1.0) + target_q = batch.returns.flatten() + act = to_torch( + batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) + + # critic 1 + current_q1 = self.critic1(batch.obs).gather(1, act).flatten() + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * weight).mean() + + self.critic1_optim.zero_grad() + critic1_loss.backward() + self.critic1_optim.step() + + # critic 2 + current_q2 = self.critic2(batch.obs).gather(1, act).flatten() + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * weight).mean() + + self.critic2_optim.zero_grad() + critic2_loss.backward() + self.critic2_optim.step() + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + dist = self(batch).dist + entropy = dist.entropy() + with torch.no_grad(): + current_q1a = self.critic1(batch.obs) + current_q2a = self.critic2(batch.obs) + q = torch.min(current_q1a, current_q2a) + actor_loss = -(self._alpha * entropy + + (dist.probs * q).sum(dim=-1)).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + if self._is_auto_alpha: + log_prob = entropy.detach() - self._target_entropy + alpha_loss = (self._log_alpha * log_prob).mean() + self._alpha_optim.zero_grad() + alpha_loss.backward() + self._alpha_optim.step() + self._alpha = self._log_alpha.detach().exp() + + self.sync_weight() + + result = { + "loss/actor": actor_loss.item(), + "loss/critic1": critic1_loss.item(), + "loss/critic2": critic2_loss.item(), + } + if self._is_auto_alpha: + result["loss/alpha"] = alpha_loss.item() + result["alpha"] = self._alpha.item() # type: ignore + + return result diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 070f11d..71d16f6 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -143,7 +143,7 @@ class DQNPolicy(BasePolicy): more detailed explanation. """ model = getattr(self, model) - obs = getattr(batch, input) + obs = batch[input] obs_ = obs.obs if hasattr(obs, "obs") else obs q, h = model(obs_, state=state, info=batch.info) act: np.ndarray = to_numpy(q.max(dim=1)[1]) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 83c7150..e44f8a1 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -28,8 +28,6 @@ class SACPolicy(DDPGPolicy): :param float tau: param for soft update of the target network, defaults to 0.005. :param float gamma: discount factor, in [0, 1], defaults to 0.99. - :param float exploration_noise: the noise intensity, add to the action, - defaults to 0.1. :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy regularization coefficient, default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then @@ -38,8 +36,8 @@ class SACPolicy(DDPGPolicy): defaults to False. :param bool ignore_done: ignore the done flag while training the policy, defaults to False. - :param BaseNoise exploration_noise: add a noise to action for exploration. - This is useful when solving hard-exploration problem. + :param BaseNoise exploration_noise: add a noise to action for exploration, + defaults to None. This is useful when solving hard-exploration problem. .. seealso:: @@ -115,7 +113,7 @@ class SACPolicy(DDPGPolicy): explorating: bool = True, **kwargs: Any, ) -> Batch: - obs = getattr(batch, input) + obs = batch[input] logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple) dist = Independent(Normal(*logits), 1) @@ -147,6 +145,7 @@ class SACPolicy(DDPGPolicy): def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) + # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() @@ -156,6 +155,7 @@ class SACPolicy(DDPGPolicy): self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() + # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q @@ -165,6 +165,7 @@ class SACPolicy(DDPGPolicy): critic2_loss.backward() self.critic2_optim.step() batch.weight = (td1 + td2) / 2.0 # prio-buffer + # actor obs_result = self(batch, explorating=False) a = obs_result.act @@ -193,5 +194,6 @@ class SACPolicy(DDPGPolicy): } if self._is_auto_alpha: result["loss/alpha"] = alpha_loss.item() - result["v/alpha"] = self._alpha.item() # type: ignore + result["alpha"] = self._alpha.item() # type: ignore + return result diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 58e44a7..547593f 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -19,10 +19,12 @@ class Actor(nn.Module): preprocess_net: nn.Module, action_shape: Sequence[int], hidden_layer_size: int = 128, + softmax_output: bool = True, ) -> None: super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) + self.softmax_output = softmax_output def forward( self, @@ -32,7 +34,9 @@ class Actor(nn.Module): ) -> Tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" logits, h = self.preprocess(s, state) - logits = F.softmax(self.last(logits), dim=-1) + logits = self.last(logits) + if self.softmax_output: + logits = F.softmax(logits, dim=-1) return logits, h @@ -44,11 +48,14 @@ class Critic(nn.Module): """ def __init__( - self, preprocess_net: nn.Module, hidden_layer_size: int = 128 + self, + preprocess_net: nn.Module, + hidden_layer_size: int = 128, + last_size: int = 1 ) -> None: super().__init__() self.preprocess = preprocess_net - self.last = nn.Linear(hidden_layer_size, 1) + self.last = nn.Linear(hidden_layer_size, last_size) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any