diff --git a/README.md b/README.md index 0807fd8..a39da24 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf) - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) +- [Branching DQN](https://arxiv.org/pdf/1711.08946.pdf) - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) - [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf) - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) diff --git a/docs/api/tianshou.env.rst b/docs/api/tianshou.env.rst index 77713f4..bf0c8e0 100644 --- a/docs/api/tianshou.env.rst +++ b/docs/api/tianshou.env.rst @@ -49,6 +49,14 @@ RayVectorEnv Wrapper ------- +ContinuousToDiscrete +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.ContinuousToDiscrete + :members: + :undoc-members: + :show-inheritance: + VectorEnvWrapper ~~~~~~~~~~~~~~~~ diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index 2719757..fe129d8 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -25,6 +25,11 @@ DQN Family :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.BranchingDQNPolicy + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: tianshou.policy.C51Policy :members: :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index 49453ad..c187827 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ +* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN `_ * :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ * :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ diff --git a/examples/box2d/README.md b/examples/box2d/README.md index f438b2a..f42f19a 100644 --- a/examples/box2d/README.md +++ b/examples/box2d/README.md @@ -4,3 +4,10 @@ - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward) ![](results/sac/BipedalHardcore.png) + + +# BipedalWalker-BDQ + +- To demonstrate the capabilities of BDQ to scale up to large discrete action spaces, we run it on a discretized version of the BipedalWalker-v3 environment, where the number of possible actions in each dimension is 25, for a total of 25^4 = 390,625 possible actions. A standard DQN architecture would use 25^4 output neurons for the Q-network, thus scaling exponentially with the number of action space dimensions, while the Branching architecture scales linearly and uses only 25*4 = 100 output neurons.
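A minimal sketch of that size comparison (illustrative only, not part of the diff; it assumes `gym` with Box2D support is installed and uses the `ContinuousToDiscrete` wrapper added in this PR):

```python
import gym

from tianshou.env import ContinuousToDiscrete

# BipedalWalker-v3 has a 4-dimensional Box action space; the wrapper turns it
# into MultiDiscrete([25, 25, 25, 25]) when action_per_branch=25.
env = ContinuousToDiscrete(gym.make("BipedalWalker-v3"), action_per_branch=25)
print(env.action_space)                    # MultiDiscrete([25 25 25 25])

num_branches = len(env.action_space.nvec)  # 4 action dimensions
print(25 ** num_branches)                  # 390625 joint actions for a flat DQN head
print(25 * num_branches)                   # 100 output neurons with the branching head
```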
+ +![](results/bdq/BipedalWalker.png) \ No newline at end of file diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py new file mode 100644 index 0000000..bb9d4d2 --- /dev/null +++ b/examples/box2d/bipedal_bdq.py @@ -0,0 +1,163 @@ +import argparse +import datetime +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv +from tianshou.policy import BranchingDQNPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import BranchingNet + + +def get_args(): + parser = argparse.ArgumentParser() + # task + parser.add_argument("--task", type=str, default="BipedalWalker-v3") + # network architecture + parser.add_argument( + "--common-hidden-sizes", type=int, nargs="*", default=[512, 256] + ) + parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--action-per-branch", type=int, default=25) + # training hyperparameters + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.) + parser.add_argument("--eps-train", type=float, default=0.73) + parser.add_argument("--eps-decay", type=float, default=5e-6) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--target-update-freq", type=int, default=1000) + parser.add_argument("--epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=80000) + parser.add_argument("--step-per-collect", type=int, default=16) + parser.add_argument("--update-per-step", type=float, default=0.0625) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--training-num", type=int, default=20) + parser.add_argument("--test-num", type=int, default=10) + # other + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + return parser.parse_args() + + +def test_bdq(args=get_args()): + env = gym.make(args.task) + env = ContinuousToDiscrete(env, args.action_per_branch) + + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.num_branches = args.action_shape if isinstance(args.action_shape, + int) else args.action_shape[0] + + print("Observations shape:", args.state_shape) + print("Num branches:", args.num_branches) + print("Actions per branch:", args.action_per_branch) + + # train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = SubprocVectorEnv( + [ + lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + for _ in range(args.training_num) + ] + ) + # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + test_envs = SubprocVectorEnv( + [ + lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + for _ in range(args.test_num) + ] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = BranchingNet( + args.state_shape, + args.num_branches, + args.action_per_branch, + args.common_hidden_sizes, + args.value_hidden_sizes, + args.action_hidden_sizes, + device=args.device, + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = BranchingDQNPolicy( + net, optim, args.gamma, target_update_freq=args.target_update_freq + ) + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs, exploration_noise=False) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + log_path = os.path.join(args.logdir, "bdq", args.task, current_time) + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): # exp decay + eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test) + policy.set_eps(eps) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + # stop_fn=stop_fn, + train_fn=train_fn, + test_fn=test_fn, + save_best_fn=save_best_fn, + logger=logger + ) + + # assert stop_fn(result["best_reward"]) + if __name__ == "__main__": + pprint.pprint(result) + # Let's watch its performance!
+ policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == "__main__": + test_bdq(get_args()) diff --git a/examples/box2d/results/bdq/BipedalWalker.png b/examples/box2d/results/bdq/BipedalWalker.png new file mode 100644 index 0000000..ee4b240 Binary files /dev/null and b/examples/box2d/results/bdq/BipedalWalker.png differ diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py new file mode 100644 index 0000000..604fd88 --- /dev/null +++ b/test/discrete/test_bdq.py @@ -0,0 +1,150 @@ +import argparse +import pprint + +import gym +import numpy as np +import torch + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import ContinuousToDiscrete, DummyVectorEnv +from tianshou.policy import BranchingDQNPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils.net.common import BranchingNet + + +def get_args(): + parser = argparse.ArgumentParser() + # task + parser.add_argument("--task", type=str, default="Pendulum-v1") + parser.add_argument('--reward-threshold', type=float, default=None) + # network architecture + parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--action-per-branch", type=int, default=40) + # training hyperparameters + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.01) + parser.add_argument("--eps-train", type=float, default=0.76) + parser.add_argument("--eps-decay", type=float, default=1e-4) + parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--gamma", type=float, default=0.9) + parser.add_argument("--target-update-freq", type=int, default=200) + parser.add_argument("--epoch", type=int, default=10) + parser.add_argument("--step-per-epoch", type=int, default=80000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + args = parser.parse_known_args()[0] + return args + + +def test_bdq(args=get_args()): + env = gym.make(args.task) + env = ContinuousToDiscrete(env, args.action_per_branch) + + args.state_shape = env.observation_space.shape or env.observation_space.n + args.num_branches = env.action_space.shape[0] + + if args.reward_threshold is None: + default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} + args.reward_threshold = default_reward_threshold.get( + args.task, env.spec.reward_threshold + ) + + print("Observations shape:", args.state_shape) + print("Num branches:", args.num_branches) + print("Actions per branch:", args.action_per_branch) + + train_envs = DummyVectorEnv( + [ + lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + for _ in range(args.training_num) + ] + ) + test_envs = DummyVectorEnv( + [ + lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) + for _ in range(args.test_num) + ] + ) + + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = BranchingNet( + args.state_shape, + args.num_branches, + args.action_per_branch, + args.common_hidden_sizes, + args.value_hidden_sizes, + args.action_hidden_sizes, + device=args.device, + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = BranchingDQNPolicy( + net, optim, args.gamma, target_update_freq=args.target_update_freq + ) + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, args.training_num), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs, exploration_noise=False) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + + def train_fn(epoch, env_step): # exp decay + eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test) + policy.set_eps(eps) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + def stop_fn(mean_rewards): + return mean_rewards >= args.reward_threshold + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + ) + + # assert stop_fn(result["best_reward"]) + if __name__ == "__main__": + pprint.pprint(result) + # Let's watch its performance! 
+ policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == "__main__": + test_bdq(get_args()) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index e23143b..ed705e2 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -1,7 +1,7 @@ import pickle from copy import deepcopy from numbers import Number -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, no_type_check import h5py import numpy as np @@ -10,6 +10,7 @@ import torch from tianshou.data.batch import Batch, _parse_value +@no_type_check def to_numpy(x: Any) -> Union[Batch, np.ndarray]: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): # most often case @@ -30,6 +31,7 @@ def to_numpy(x: Any) -> Union[Batch, np.ndarray]: return np.asanyarray(x) +@no_type_check def to_torch( x: Any, dtype: Optional[torch.dtype] = None, @@ -39,14 +41,14 @@ def to_torch( if isinstance(x, np.ndarray) and issubclass( x.dtype.type, (np.bool_, np.number) ): # most often case - x = torch.from_numpy(x).to(device) # type: ignore + x = torch.from_numpy(x).to(device) if dtype is not None: x = x.type(dtype) return x elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) - return x.to(device) # type: ignore + return x.to(device) elif isinstance(x, (np.number, np.bool_, Number)): return to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (dict, Batch)): @@ -59,6 +61,7 @@ def to_torch( raise TypeError(f"object {x} cannot be converted to torch.") +@no_type_check def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]: """Return an object without np.ndarray. diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 66fbc68..8b1c71a 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,5 +1,6 @@ """Env package.""" +from tianshou.env.gym_wrappers import ContinuousToDiscrete from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( BaseVectorEnv, @@ -23,4 +24,5 @@ __all__ = [ "VectorEnvWrapper", "VectorEnvNormObs", "PettingZooEnv", + "ContinuousToDiscrete", ] diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py new file mode 100644 index 0000000..f63bc9e --- /dev/null +++ b/tianshou/env/gym_wrappers.py @@ -0,0 +1,28 @@ +import gym +import numpy as np + + +class ContinuousToDiscrete(gym.ActionWrapper): + """Gym environment wrapper to take discrete action in a continuous environment. + + :param gym.Env env: gym environment with continuous action space. + :param int action_per_branch: number of discrete actions in each dimension + of the action space. 
+ """ + + def __init__(self, env: gym.Env, action_per_branch: int) -> None: + super().__init__(env) + assert isinstance(env.action_space, gym.spaces.Box) + low, high = env.action_space.low, env.action_space.high + num_branches = env.action_space.shape[0] + self.action_space = gym.spaces.MultiDiscrete( + [action_per_branch] * num_branches + ) + mesh = [] + for lo, hi in zip(low, high): + mesh.append(np.linspace(lo, hi, action_per_branch)) + self.mesh = np.array(mesh) + + def action(self, act: np.ndarray) -> np.ndarray: + # modify act + return np.array([self.mesh[i][a] for i, a in enumerate(act)]) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index f7774cb..8007bad 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -4,6 +4,7 @@ from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.bdq import BranchingDQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.rainbow import RainbowPolicy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy @@ -34,6 +35,7 @@ __all__ = [ "BasePolicy", "RandomPolicy", "DQNPolicy", + "BranchingDQNPolicy", "C51Policy", "RainbowPolicy", "QRDQNPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 325a85b..893905d 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -339,7 +339,7 @@ class BasePolicy(ABC, nn.Module): assert np.isclose(gae_lambda, 1.0) v_s_ = np.zeros_like(rew) else: - v_s_ = to_numpy(v_s_.flatten()) # type: ignore + v_s_ = to_numpy(v_s_.flatten()) v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices) v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten()) diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 20c6a62..211c093 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -54,11 +54,11 @@ class ImitationPolicy(BasePolicy): if self.action_type == "continuous": # regression act = self(batch).act act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) - loss = F.mse_loss(act, act_target) # type: ignore + loss = F.mse_loss(act, act_target) elif self.action_type == "discrete": # classification act = F.log_softmax(self(batch).logits, dim=-1) act_target = to_torch(batch.act, dtype=torch.long, device=act.device) - loss = F.nll_loss(act, act_target) # type: ignore + loss = F.nll_loss(act, act_target) loss.backward() self.optim.step() return {"loss": loss.item()} diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 89facd7..883400d 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -104,9 +104,7 @@ class BCQPolicy(BasePolicy): """Compute action over the given batch data.""" # There is "obs" in the Batch # obs_group: several groups. Each group has a state. - obs_group: torch.Tensor = to_torch( # type: ignore - batch.obs, device=self.device - ) + obs_group: torch.Tensor = to_torch(batch.obs, device=self.device) act_group = [] for obs in obs_group: # now obs is (state_dim) @@ -132,9 +130,7 @@ class BCQPolicy(BasePolicy): def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # batch: obs, act, rew, done, obs_next. 
(numpy array) # (batch_size, state_dim) - batch: Batch = to_torch( # type: ignore - batch, dtype=torch.float, device=self.device - ) + batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) obs, act = batch.obs, batch.act batch_size = obs.shape[0] diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index dd890ab..7d9cea9 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -160,9 +160,7 @@ class CQLPolicy(SACPolicy): return batch def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - batch: Batch = to_torch( # type: ignore - batch, dtype=torch.float, device=self.device, - ) + batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] @@ -208,12 +206,12 @@ class CQLPolicy(SACPolicy): random_actions = torch.FloatTensor( batch_size * self.num_repeat_actions, act.shape[-1] ).uniform_(-self.min_action, self.max_action).to(self.device) - tmp_obs = obs.unsqueeze(1) \ - .repeat(1, self.num_repeat_actions, 1) \ - .view(batch_size * self.num_repeat_actions, obs.shape[-1]) - tmp_obs_next = obs_next.unsqueeze(1) \ - .repeat(1, self.num_repeat_actions, 1) \ - .view(batch_size * self.num_repeat_actions, obs.shape[-1]) + + obs_len = len(obs.shape) + repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1) + view_size = [batch_size * self.num_repeat_actions] + list(obs.shape[1:]) + tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size) + tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 95b5dea..1f713b5 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -113,10 +113,7 @@ class DiscreteBCQPolicy(DQNPolicy): current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) - i_loss = F.nll_loss( - F.log_softmax(imitation_logits, dim=-1), - act # type: ignore - ) + i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 434e12e..b85f52d 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -108,9 +108,9 @@ class GAILPolicy(PPOPolicy): return super().process_fn(batch, buffer, indices) def disc(self, batch: Batch) -> torch.Tensor: - obs = to_torch(batch.obs, device=self.disc_net.device) # type: ignore - act = to_torch(batch.act, device=self.disc_net.device) # type: ignore - return self.disc_net(torch.cat([obs, act], dim=1)) # type: ignore + obs = to_torch(batch.obs, device=self.disc_net.device) + act = to_torch(batch.act, device=self.disc_net.device) + return self.disc_net(torch.cat([obs, act], dim=1)) def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 66ab1bb..97fcde2 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -105,7 +105,7 @@ class ICMPolicy(BasePolicy): 
self.optim.zero_grad() act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) - inverse_loss = F.cross_entropy(act_hat, act).mean() # type: ignore + inverse_loss = F.cross_entropy(act_hat, act).mean() forward_loss = batch.policy.mse_loss.mean() loss = ( (1 - self.forward_loss_weight) * inverse_loss + diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py new file mode 100644 index 0000000..b02514c --- /dev/null +++ b/tianshou/policy/modelfree/bdq.py @@ -0,0 +1,144 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as +from tianshou.policy import DQNPolicy +from tianshou.utils.net.common import BranchingNet + + +class BranchingDQNPolicy(DQNPolicy): + """Implementation of the Branching dual Q network arXiv:1711.08946. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int estimation_step: the number of steps to look ahead. Default to 1. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). Default to 0. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + :param bool is_double: use double network. Default to True. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + model: BranchingNet, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, is_double + ) + assert estimation_step == 1, "N-step bigger than one is not supported by BDQ" + self.max_action_num = model.action_per_branch + self.num_branches = model.num_branches + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + batch = buffer[indices] # batch.obs_next: s_{t+n} + result = self(batch, input="obs_next") + if self._target: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self(batch, model="model_old", input="obs_next").logits + else: + target_q = result.logits + if self._is_double: + act = np.expand_dims(self(batch, input="obs_next").act, -1) + act = to_torch(act, dtype=torch.long, device=target_q.device) + else: + act = target_q.max(-1).indices.unsqueeze(-1) + return torch.gather(target_q, -1, act).squeeze() + + def _compute_return( + self, + batch: Batch, + buffer: ReplayBuffer, + indice: np.ndarray, + gamma: float = 0.99, + ) -> Batch: + rew = batch.rew + with torch.no_grad(): + target_q_torch = self._target_q(buffer, indice) # (bsz, ?) 
+ target_q = to_numpy(target_q_torch) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + end_flag = end_flag[indice] + mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q + _target_q = rew + gamma * mean_target_q * (1 - end_flag) + target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1) + target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1) + + batch.returns = to_torch_as(target_q, target_q_torch) + if hasattr(batch, "weight"): # prio buffer update + batch.weight = to_torch_as(batch.weight, target_q_torch) + return batch + + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> Batch: + """Compute the 1-step return for BDQ targets.""" + return self._compute_return(batch, buffer, indices) + + def forward( + self, + batch: Batch, + state: Optional[Union[Dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + **kwargs: Any, + ) -> Batch: + model = getattr(self, model) + obs = batch[input] + obs_next = obs.obs if hasattr(obs, "obs") else obs + logits, hidden = model(obs_next, state=state, info=batch.info) + act = to_numpy(logits.max(dim=-1)[1]) + return Batch(logits=logits, act=act, state=hidden) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) + q = self(batch).logits + act_mask = torch.zeros_like(q) + act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) + act_q = q * act_mask + returns = batch.returns + returns = returns * act_mask + td_error = returns - act_q + loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() + batch.weight = td_error.sum(-1).sum(-1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + return {"loss": loss.item()} + + def exploration_noise( + self, + act: Union[np.ndarray, Batch], + batch: Batch, + ) -> Union[np.ndarray, Batch]: + if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + rand_act = np.random.randint( + low=0, high=self.max_action_num, size=(bsz, act.shape[-1]) + ) + if hasattr(batch.obs, "mask"): + rand_act += batch.obs.mask + act[rand_mask] = rand_act[rand_mask] + return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 593de15..e36b7f5 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -138,8 +138,6 @@ class DQNPolicy(BasePolicy): ... ) - :param float eps: in [0, 1], for epsilon-greedy exploration method. - :return: A :class:`~tianshou.data.Batch` which has 3 keys: * ``act`` the action. 
@@ -177,8 +175,11 @@ class DQNPolicy(BasePolicy): self._iter += 1 return {"loss": loss.item()} - def exploration_noise(self, act: Union[np.ndarray, Batch], - batch: Batch) -> Union[np.ndarray, Batch]: + def exploration_noise( + self, + act: Union[np.ndarray, Batch], + batch: Batch, + ) -> Union[np.ndarray, Batch]: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 64283b4..1fc58f7 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -363,3 +363,93 @@ class EnsembleLinear(nn.Module): if self.bias is not None: x = x + self.bias return x + + +class BranchingNet(nn.Module): + """Branching dual Q network. + + Network for the BranchingDQNPolicy. It uses a common network module, a value + module, and one action "branch" module per action dimension. This allows the + Q-value output to scale linearly with the number of dimensions in the action + space. + For more info please refer to: arXiv:1711.08946. + :param state_shape: int or a sequence of int of the shape of state. + :param num_branches: number of action dimensions (branches) in the action space. + :param action_per_branch: number of discrete actions in each dimension of the + action space. + :param common_hidden_sizes: shape of the common MLP network passed in as a list. + :param value_hidden_sizes: shape of the value MLP network passed in as a list. + :param action_hidden_sizes: shape of the action MLP network passed in as a list. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param device: specify the device when the network actually runs. Default + to "cpu".
+ """ + + def __init__( + self, + state_shape: Union[int, Sequence[int]], + num_branches: int = 0, + action_per_branch: int = 2, + common_hidden_sizes: List[int] = [], + value_hidden_sizes: List[int] = [], + action_hidden_sizes: List[int] = [], + norm_layer: Optional[ModuleType] = None, + activation: Optional[ModuleType] = nn.ReLU, + device: Union[str, int, torch.device] = "cpu", + ) -> None: + super().__init__() + self.device = device + self.num_branches = num_branches + self.action_per_branch = action_per_branch + # common network + common_input_dim = int(np.prod(state_shape)) + common_output_dim = 0 + self.common = MLP( + common_input_dim, common_output_dim, common_hidden_sizes, norm_layer, + activation, device + ) + # value network + value_input_dim = common_hidden_sizes[-1] + value_output_dim = 1 + self.value = MLP( + value_input_dim, value_output_dim, value_hidden_sizes, norm_layer, + activation, device + ) + # action branching network + action_input_dim = common_hidden_sizes[-1] + action_output_dim = action_per_branch + self.branches = nn.ModuleList( + [ + MLP( + action_input_dim, action_output_dim, action_hidden_sizes, + norm_layer, activation, device + ) for _ in range(self.num_branches) + ] + ) + + def forward( + self, + obs: Union[np.ndarray, torch.Tensor], + state: Any = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + """Mapping: obs -> model -> logits.""" + common_out = self.common(obs) + value_out = self.value(common_out) + value_out = torch.unsqueeze(value_out, 1) + action_out = [] + for b in self.branches: + action_out.append(b(common_out)) + action_scores = torch.stack(action_out, 1) + action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True) + logits = value_out + action_scores + return logits, state