Add BranchingDQN for large discrete action spaces (#618)

This commit is contained in:
Anas BELFADIL 2022-05-15 15:40:32 +02:00 committed by GitHub
parent a03f19af72
commit 53e6b0408d
22 changed files with 629 additions and 33 deletions

View File

@ -12,6 +12,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Branching DQN](https://arxiv.org/pdf/1711.08946.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)

View File

@ -49,6 +49,14 @@ RayVectorEnv
Wrapper
-------
ContinuousToDiscrete
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: tianshou.env.ContinuousToDiscrete
:members:
:undoc-members:
:show-inheritance:
VectorEnvWrapper
~~~~~~~~~~~~~~~~

View File

@ -25,6 +25,11 @@ DQN Family
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.BranchingDQNPolicy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.C51Policy
:members:
:undoc-members:

View File

@ -12,6 +12,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN <https://arxiv.org/pdf/1711.08946.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1710.02298.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_

View File

@ -4,3 +4,10 @@
- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
![](results/sac/BipedalHardcore.png)
# BipedalWalker-BDQ
- To demonstrate the capability of BDQ to scale up to large discrete action spaces, we run it on a discretized version of the BipedalWalker-v3 environment, where each of the 4 action dimensions is discretized into 25 possible actions, for a total of 25^4 = 390,625 possible action combinations. A usual DQN architecture would need 25^4 output neurons for its Q-network, thus scaling exponentially with the number of action-space dimensions, while the branching architecture scales linearly and uses only 25*4 = 100 output neurons (see the sketch below).
![](results/bdq/BipedalWalker.png)
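A minimal sketch of this setup (assuming Box2D's `BipedalWalker-v3` is installed), illustrating the discretization and the linear vs. exponential output count:

```python
import gym

from tianshou.env import ContinuousToDiscrete

# BipedalWalker-v3 has a 4-dimensional continuous action space; the wrapper
# discretizes each dimension into 25 bins.
env = ContinuousToDiscrete(gym.make("BipedalWalker-v3"), action_per_branch=25)
print(env.action_space)  # MultiDiscrete([25 25 25 25])

num_branches = len(env.action_space.nvec)  # 4 action dimensions
print(25 ** num_branches)  # 390625 Q-values for a flat DQN output layer
print(25 * num_branches)   # 100 Q-values for the branching output layers
```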

View File

@ -0,0 +1,163 @@
import argparse
import datetime
import os
import pprint
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import BranchingNet
def get_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--task", type=str, default="BipedalWalker-v3")
# network architecture
parser.add_argument(
"--common-hidden-sizes", type=int, nargs="*", default=[512, 256]
)
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128])
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128])
parser.add_argument("--action-per-branch", type=int, default=25)
# training hyperparameters
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.)
parser.add_argument("--eps-train", type=float, default=0.73)
parser.add_argument("--eps-decay", type=float, default=5e-6)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--target-update-freq", type=int, default=1000)
parser.add_argument("--epoch", type=int, default=1000)
parser.add_argument("--step-per-epoch", type=int, default=80000)
parser.add_argument("--step-per-collect", type=int, default=16)
parser.add_argument("--update-per-step", type=float, default=0.0625)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--training-num", type=int, default=20)
parser.add_argument("--test-num", type=int, default=10)
# other
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
return parser.parse_args()
def test_bdq(args=get_args()):
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.num_branches = args.action_shape if isinstance(args.action_shape, int) else args.action_shape[0]
print("Observations shape:", args.state_shape)
print("Num branches:", args.num_branches)
print("Actions per branch:", args.action_per_branch)
# train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = SubprocVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.training_num)
]
)
# test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
test_envs = SubprocVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.test_num)
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = BranchingNet(
args.state_shape,
args.num_branches,
args.action_per_branch,
args.common_hidden_sizes,
args.value_hidden_sizes,
args.action_hidden_sizes,
device=args.device,
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = BranchingDQNPolicy(
net, optim, args.gamma, target_update_freq=args.target_update_freq
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_path = os.path.join(args.logdir, "bdq", args.task, current_time)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)
def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
policy.set_eps(eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
# stop_fn=stop_fn,
train_fn=train_fn,
test_fn=test_fn,
save_best_fn=save_best_fn,
logger=logger
)
# assert stop_fn(result["best_reward"])
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == "__main__":
test_bdq(get_args())

Binary file not shown (new image, 87 KiB).

150
test/discrete/test_bdq.py Normal file
View File

@ -0,0 +1,150 @@
import argparse
import pprint
import gym
import numpy as np
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, DummyVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import BranchingNet
def get_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--task", type=str, default="Pendulum-v1")
parser.add_argument('--reward-threshold', type=float, default=None)
# network architecture
parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64])
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64])
parser.add_argument("--action-per-branch", type=int, default=40)
# training hyperparameters
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.01)
parser.add_argument("--eps-train", type=float, default=0.76)
parser.add_argument("--eps-decay", type=float, default=1e-4)
parser.add_argument("--buffer-size", type=int, default=20000)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--gamma", type=float, default=0.9)
parser.add_argument("--target-update-freq", type=int, default=200)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=80000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
args = parser.parse_known_args()[0]
return args
def test_bdq(args=get_args()):
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.num_branches = env.action_space.shape[0]
if args.reward_threshold is None:
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
args.reward_threshold = default_reward_threshold.get(
args.task, env.spec.reward_threshold
)
print("Observations shape:", args.state_shape)
print("Num branches:", args.num_branches)
print("Actions per branch:", args.action_per_branch)
train_envs = DummyVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.training_num)
]
)
test_envs = DummyVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.test_num)
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = BranchingNet(
args.state_shape,
args.num_branches,
args.action_per_branch,
args.common_hidden_sizes,
args.value_hidden_sizes,
args.action_hidden_sizes,
device=args.device,
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = BranchingDQNPolicy(
net, optim, args.gamma, target_update_freq=args.target_update_freq
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, args.training_num),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
policy.set_eps(eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
)
# assert stop_fn(result["best_reward"])
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == "__main__":
test_bdq(get_args())

View File

@ -1,7 +1,7 @@
import pickle
from copy import deepcopy
from numbers import Number
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, no_type_check
import h5py
import numpy as np
@ -10,6 +10,7 @@ import torch
from tianshou.data.batch import Batch, _parse_value
@no_type_check
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor): # most often case
@ -30,6 +31,7 @@ def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
return np.asanyarray(x)
@no_type_check
def to_torch(
x: Any,
dtype: Optional[torch.dtype] = None,
@ -39,14 +41,14 @@ def to_torch(
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
): # most often case
x = torch.from_numpy(x).to(device) # type: ignore
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
return x
elif isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
return x.to(device) # type: ignore
return x.to(device)
elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, (dict, Batch)):
@ -59,6 +61,7 @@ def to_torch(
raise TypeError(f"object {x} cannot be converted to torch.")
@no_type_check
def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
"""Return an object without np.ndarray.

View File

@ -1,5 +1,6 @@
"""Env package."""
from tianshou.env.gym_wrappers import ContinuousToDiscrete
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
from tianshou.env.venvs import (
BaseVectorEnv,
@ -23,4 +24,5 @@ __all__ = [
"VectorEnvWrapper",
"VectorEnvNormObs",
"PettingZooEnv",
"ContinuousToDiscrete",
]

28
tianshou/env/gym_wrappers.py vendored Normal file
View File

@ -0,0 +1,28 @@
import gym
import numpy as np
class ContinuousToDiscrete(gym.ActionWrapper):
"""Gym environment wrapper to take discrete action in a continuous environment.
:param gym.Env env: gym environment with continuous action space.
:param int action_per_branch: number of discrete actions in each dimension
of the action space.
"""
def __init__(self, env: gym.Env, action_per_branch: int) -> None:
super().__init__(env)
assert isinstance(env.action_space, gym.spaces.Box)
low, high = env.action_space.low, env.action_space.high
num_branches = env.action_space.shape[0]
self.action_space = gym.spaces.MultiDiscrete(
[action_per_branch] * num_branches
)
mesh = []
for lo, hi in zip(low, high):
mesh.append(np.linspace(lo, hi, action_per_branch))
self.mesh = np.array(mesh)
def action(self, act: np.ndarray) -> np.ndarray:
# map each branch's discrete index to the corresponding continuous action value
return np.array([self.mesh[i][a] for i, a in enumerate(act)])
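
A brief usage sketch of the wrapper above (assuming `Pendulum-v1` is available): each discrete index along a branch selects one point of an evenly spaced grid between the Box bounds.

```python
import gym
import numpy as np

from tianshou.env import ContinuousToDiscrete

# Pendulum-v1 has a single Box action dimension in [-2, 2].
env = ContinuousToDiscrete(gym.make("Pendulum-v1"), action_per_branch=5)
print(env.action_space)           # MultiDiscrete([5])
# Index 0 maps to the lower bound, the last index to the upper bound.
print(env.action(np.array([0])))  # [-2.]
print(env.action(np.array([4])))  # [2.]
```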

View File

@ -4,6 +4,7 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.bdq import BranchingDQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.rainbow import RainbowPolicy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
@ -34,6 +35,7 @@ __all__ = [
"BasePolicy",
"RandomPolicy",
"DQNPolicy",
"BranchingDQNPolicy",
"C51Policy",
"RainbowPolicy",
"QRDQNPolicy",

View File

@ -339,7 +339,7 @@ class BasePolicy(ABC, nn.Module):
assert np.isclose(gae_lambda, 1.0)
v_s_ = np.zeros_like(rew)
else:
v_s_ = to_numpy(v_s_.flatten()) # type: ignore
v_s_ = to_numpy(v_s_.flatten())
v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices)
v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())

View File

@ -54,11 +54,11 @@ class ImitationPolicy(BasePolicy):
if self.action_type == "continuous": # regression
act = self(batch).act
act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
loss = F.mse_loss(act, act_target) # type: ignore
loss = F.mse_loss(act, act_target)
elif self.action_type == "discrete": # classification
act = F.log_softmax(self(batch).logits, dim=-1)
act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
loss = F.nll_loss(act, act_target) # type: ignore
loss = F.nll_loss(act, act_target)
loss.backward()
self.optim.step()
return {"loss": loss.item()}

View File

@ -104,9 +104,7 @@ class BCQPolicy(BasePolicy):
"""Compute action over the given batch data."""
# There is "obs" in the Batch
# obs_group: several groups. Each group has a state.
obs_group: torch.Tensor = to_torch( # type: ignore
batch.obs, device=self.device
)
obs_group: torch.Tensor = to_torch(batch.obs, device=self.device)
act_group = []
for obs in obs_group:
# now obs is (state_dim)
@ -132,9 +130,7 @@ class BCQPolicy(BasePolicy):
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# batch: obs, act, rew, done, obs_next. (numpy array)
# (batch_size, state_dim)
batch: Batch = to_torch( # type: ignore
batch, dtype=torch.float, device=self.device
)
batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
obs, act = batch.obs, batch.act
batch_size = obs.shape[0]

View File

@ -160,9 +160,7 @@ class CQLPolicy(SACPolicy):
return batch
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
batch: Batch = to_torch( # type: ignore
batch, dtype=torch.float, device=self.device,
)
batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
batch_size = obs.shape[0]
@ -208,12 +206,12 @@ class CQLPolicy(SACPolicy):
random_actions = torch.FloatTensor(
batch_size * self.num_repeat_actions, act.shape[-1]
).uniform_(-self.min_action, self.max_action).to(self.device)
tmp_obs = obs.unsqueeze(1) \
.repeat(1, self.num_repeat_actions, 1) \
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
tmp_obs_next = obs_next.unsqueeze(1) \
.repeat(1, self.num_repeat_actions, 1) \
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
obs_len = len(obs.shape)
repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1)
view_size = [batch_size * self.num_repeat_actions] + list(obs.shape[1:])
tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size)
tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size)
# tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)

View File

@ -113,10 +113,7 @@ class DiscreteBCQPolicy(DQNPolicy):
current_q = result.q_value[np.arange(len(target_q)), batch.act]
act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
q_loss = F.smooth_l1_loss(current_q, target_q)
i_loss = F.nll_loss(
F.log_softmax(imitation_logits, dim=-1),
act # type: ignore
)
i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act)
reg_loss = imitation_logits.pow(2).mean()
loss = q_loss + i_loss + self._weight_reg * reg_loss

View File

@ -108,9 +108,9 @@ class GAILPolicy(PPOPolicy):
return super().process_fn(batch, buffer, indices)
def disc(self, batch: Batch) -> torch.Tensor:
obs = to_torch(batch.obs, device=self.disc_net.device) # type: ignore
act = to_torch(batch.act, device=self.disc_net.device) # type: ignore
return self.disc_net(torch.cat([obs, act], dim=1)) # type: ignore
obs = to_torch(batch.obs, device=self.disc_net.device)
act = to_torch(batch.act, device=self.disc_net.device)
return self.disc_net(torch.cat([obs, act], dim=1))
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

View File

@ -105,7 +105,7 @@ class ICMPolicy(BasePolicy):
self.optim.zero_grad()
act_hat = batch.policy.act_hat
act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
inverse_loss = F.cross_entropy(act_hat, act).mean() # type: ignore
inverse_loss = F.cross_entropy(act_hat, act).mean()
forward_loss = batch.policy.mse_loss.mean()
loss = (
(1 - self.forward_loss_weight) * inverse_loss +

View File

@ -0,0 +1,144 @@
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import BranchingNet
class BranchingDQNPolicy(DQNPolicy):
"""Implementation of the Branching dual Q network arXiv:1711.08946.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int estimation_step: the number of steps to look ahead. Default to 1.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool is_double: use double network. Default to True.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
model: BranchingNet,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
is_double: bool = True,
**kwargs: Any,
) -> None:
super().__init__(
model, optim, discount_factor, estimation_step, target_update_freq,
reward_normalization, is_double
)
assert estimation_step == 1, "N-step bigger than one is not supported by BDQ"
self.max_action_num = model.action_per_branch
self.num_branches = model.num_branches
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
result = self(batch, input="obs_next")
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
target_q = self(batch, model="model_old", input="obs_next").logits
else:
target_q = result.logits
if self._is_double:
act = np.expand_dims(self(batch, input="obs_next").act, -1)
act = to_torch(act, dtype=torch.long, device=target_q.device)
else:
act = target_q.max(-1).indices.unsqueeze(-1)
return torch.gather(target_q, -1, act).squeeze()
def _compute_return(
self,
batch: Batch,
buffer: ReplayBuffer,
indice: np.ndarray,
gamma: float = 0.99,
) -> Batch:
rew = batch.rew
with torch.no_grad():
target_q_torch = self._target_q(buffer, indice) # (bsz, ?)
target_q = to_numpy(target_q_torch)
end_flag = buffer.done.copy()
end_flag[buffer.unfinished_index()] = True
end_flag = end_flag[indice]
mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
_target_q = rew + gamma * mean_target_q * (1 - end_flag)
target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)
batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update
batch.weight = to_torch_as(batch.weight, target_q_torch)
return batch
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
"""Compute the 1-step return for BDQ targets."""
return self._compute_return(batch, buffer, indices)
def forward(
self,
batch: Batch,
state: Optional[Union[Dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
**kwargs: Any,
) -> Batch:
model = getattr(self, model)
obs = batch[input]
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1])
return Batch(logits=logits, act=act, state=hidden)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
q = self(batch).logits
act_mask = torch.zeros_like(q)
act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
act_q = q * act_mask
returns = batch.returns
returns = returns * act_mask
td_error = returns - act_q
loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
batch.weight = td_error.sum(-1).sum(-1) # prio-buffer
loss.backward()
self.optim.step()
self._iter += 1
return {"loss": loss.item()}
def exploration_noise(
self,
act: Union[np.ndarray, Batch],
batch: Batch,
) -> Union[np.ndarray, Batch]:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
rand_act = np.random.randint(
low=0, high=self.max_action_num, size=(bsz, act.shape[-1])
)
if hasattr(batch.obs, "mask"):
rand_act += batch.obs.mask
act[rand_mask] = rand_act[rand_mask]
return act
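
A minimal usage sketch of the policy above, with hypothetical dimensions (24-dimensional observations, 4 branches, 25 actions per branch); a forward pass returns one Q-vector and one greedy action index per branch:

```python
import numpy as np
import torch

from tianshou.data import Batch
from tianshou.policy import BranchingDQNPolicy
from tianshou.utils.net.common import BranchingNet

net = BranchingNet(
    state_shape=24,
    num_branches=4,
    action_per_branch=25,
    common_hidden_sizes=[128],
    value_hidden_sizes=[64],
    action_hidden_sizes=[64],
)
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
policy = BranchingDQNPolicy(net, optim, discount_factor=0.99, target_update_freq=1000)

batch = Batch(obs=np.random.randn(8, 24).astype(np.float32), info={})
out = policy(batch)
print(out.logits.shape)  # torch.Size([8, 4, 25]): one Q-vector per branch
print(out.act.shape)     # (8, 4): one discrete action index per branch
```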

View File

@ -138,8 +138,6 @@ class DQNPolicy(BasePolicy):
...
)
:param float eps: in [0, 1], for epsilon-greedy exploration method.
:return: A :class:`~tianshou.data.Batch` which has 3 keys:
* ``act`` the action.
@ -177,8 +175,11 @@ class DQNPolicy(BasePolicy):
self._iter += 1
return {"loss": loss.item()}
def exploration_noise(self, act: Union[np.ndarray, Batch],
batch: Batch) -> Union[np.ndarray, Batch]:
def exploration_noise(
self,
act: Union[np.ndarray, Batch],
batch: Batch,
) -> Union[np.ndarray, Batch]:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps

View File

@ -363,3 +363,93 @@ class EnsembleLinear(nn.Module):
if self.bias is not None:
x = x + self.bias
return x
class BranchingNet(nn.Module):
"""Branching dual Q network.
Network for the BranchingDQNPolicy, it uses a common network module, a value module
and action "branches" one for each dimension.It allows for a linear scaling
of Q-value the output w.r.t. the number of dimensions in the action space.
For more info please refer to: arXiv:1711.08946.
:param state_shape: int or a sequence of int of the shape of state.
:param num_branches: number of action dimensions (one branch per dimension).
:param action_per_branch: number of discrete actions in each dimension
of the action space.
:param common_hidden_sizes: shape of the common MLP network passed in as a list.
:param value_hidden_sizes: shape of the value MLP network passed in as a list.
:param action_hidden_sizes: shape of the action MLP network passed in as a list.
:param norm_layer: use which normalization before activation, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
You can also pass a list of normalization modules with the same length
of hidden_sizes, to use different normalization module in different
layers. Default to no normalization.
:param activation: which activation to use after each layer, can be both
the same activation for all layers if passed in nn.Module, or different
activation for different Modules if passed in a list. Default to
nn.ReLU.
:param device: specify the device when the network actually runs. Default
to "cpu".
"""
def __init__(
self,
state_shape: Union[int, Sequence[int]],
num_branches: int = 0,
action_per_branch: int = 2,
common_hidden_sizes: List[int] = [],
value_hidden_sizes: List[int] = [],
action_hidden_sizes: List[int] = [],
norm_layer: Optional[ModuleType] = None,
activation: Optional[ModuleType] = nn.ReLU,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
self.num_branches = num_branches
self.action_per_branch = action_per_branch
# common network
common_input_dim = int(np.prod(state_shape))
common_output_dim = 0
self.common = MLP(
common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
activation, device
)
# value network
value_input_dim = common_hidden_sizes[-1]
value_output_dim = 1
self.value = MLP(
value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
activation, device
)
# action branching network
action_input_dim = common_hidden_sizes[-1]
action_output_dim = action_per_branch
self.branches = nn.ModuleList(
[
MLP(
action_input_dim, action_output_dim, action_hidden_sizes,
norm_layer, activation, device
) for _ in range(self.num_branches)
]
)
def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
state: Any = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: obs -> model -> logits."""
common_out = self.common(obs)
value_out = self.value(common_out)
value_out = torch.unsqueeze(value_out, 1)
action_out = []
for b in self.branches:
action_out.append(b(common_out))
action_scores = torch.stack(action_out, 1)
action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
logits = value_out + action_scores
return logits, state
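
A small shape sketch of `BranchingNet.forward` above, with hypothetical sizes: the value head contributes one scalar per state, each branch contributes `action_per_branch` mean-centered advantages, and their sum is a dueling aggregation, Q = V + (A - mean(A)), yielding one Q-vector per branch.

```python
import numpy as np

from tianshou.utils.net.common import BranchingNet

net = BranchingNet(
    state_shape=8, num_branches=3, action_per_branch=5,
    common_hidden_sizes=[32], value_hidden_sizes=[16], action_hidden_sizes=[16],
)
obs = np.random.randn(4, 8).astype(np.float32)  # batch of 4 observations
logits, state = net(obs)
print(logits.shape)       # torch.Size([4, 3, 5]): (batch, num_branches, action_per_branch)
print(logits.argmax(-1))  # greedy action index for each of the 3 branches
```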