Add BranchingDQN for large discrete action spaces (#618)
parent a03f19af72
commit 53e6b0408d
@@ -12,6 +12,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Branching DQN](https://arxiv.org/pdf/1711.08946.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
@@ -49,6 +49,14 @@ RayVectorEnv
Wrapper
-------

ContinuousToDiscrete
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: tianshou.env.ContinuousToDiscrete
   :members:
   :undoc-members:
   :show-inheritance:

VectorEnvWrapper
~~~~~~~~~~~~~~~~
@@ -25,6 +25,11 @@ DQN Family
   :undoc-members:
   :show-inheritance:

.. autoclass:: tianshou.policy.BranchingDQNPolicy
   :members:
   :undoc-members:
   :show-inheritance:

.. autoclass:: tianshou.policy.C51Policy
   :members:
   :undoc-members:
@@ -12,6 +12,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN <https://arxiv.org/pdf/1711.08946.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1710.02298.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
@@ -4,3 +4,10 @@
- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (~200 reward)

(training curve figure)

# BipedalWalker-BDQ

- To demonstrate the capability of BDQ to scale up to large discrete action spaces, we run it on a discretized version of the BipedalWalker-v3 environment, where the number of possible actions in each dimension is 25, for a total of 25^4 = 390,625 possible actions. A usual DQN architecture would need 25^4 output neurons for the Q-network, thus scaling exponentially with the number of action-space dimensions, while the branching architecture scales linearly and uses only 25*4 = 100 output neurons.

![BipedalWalker-BDQ](results/bdq/BipedalWalker.png)
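To make the scaling argument above concrete, here is a minimal sketch (not part of the diff; it only assumes the ContinuousToDiscrete wrapper added in this PR and gym's BipedalWalker-v3):

import gym

from tianshou.env import ContinuousToDiscrete

# Discretize each of BipedalWalker's 4 continuous action dimensions into 25 bins.
env = ContinuousToDiscrete(gym.make("BipedalWalker-v3"), 25)
print(env.action_space)  # MultiDiscrete([25 25 25 25])

# A flat DQN head needs one output per joint action,
# while a branching head needs one output per (branch, bin) pair.
print(25 ** 4)  # 390625
print(25 * 4)   # 100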
examples/box2d/bipedal_bdq.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv
|
||||
from tianshou.policy import BranchingDQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import BranchingNet
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# task
|
||||
parser.add_argument("--task", type=str, default="BipedalWalker-v3")
|
||||
# network architecture
|
||||
parser.add_argument(
|
||||
"--common-hidden-sizes", type=int, nargs="*", default=[512, 256]
|
||||
)
|
||||
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128])
|
||||
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128])
|
||||
parser.add_argument("--action-per-branch", type=int, default=25)
|
||||
# training hyperparameters
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--eps-test", type=float, default=0.)
|
||||
parser.add_argument("--eps-train", type=float, default=0.73)
|
||||
parser.add_argument("--eps-decay", type=float, default=5e-6)
|
||||
parser.add_argument("--buffer-size", type=int, default=100000)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--target-update-freq", type=int, default=1000)
|
||||
parser.add_argument("--epoch", type=int, default=1000)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=80000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=16)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.0625)
|
||||
parser.add_argument("--batch-size", type=int, default=512)
|
||||
parser.add_argument("--training-num", type=int, default=20)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
# other
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_bdq(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
env = ContinuousToDiscrete(env, args.action_per_branch)
|
||||
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.num_branches = args.action_shape if isinstance(args.action_shape,
|
||||
int) else args.action_shape[0]
|
||||
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Num branches:", args.num_branches)
|
||||
print("Actions per branch:", args.action_per_branch)
|
||||
|
||||
# train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = SubprocVectorEnv(
|
||||
[
|
||||
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
for _ in range(args.training_num)
|
||||
]
|
||||
)
|
||||
# test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[
|
||||
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
for _ in range(args.test_num)
|
||||
]
|
||||
)
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = BranchingNet(
|
||||
args.state_shape,
|
||||
args.num_branches,
|
||||
args.action_per_branch,
|
||||
args.common_hidden_sizes,
|
||||
args.value_hidden_sizes,
|
||||
args.action_hidden_sizes,
|
||||
device=args.device,
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = BranchingDQNPolicy(
|
||||
net, optim, args.gamma, target_update_freq=args.target_update_freq
|
||||
)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=False)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
log_path = os.path.join(args.logdir, "bdq", args.task, current_time)
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= getattr(env.spec.reward_threshold)
|
||||
|
||||
def train_fn(epoch, env_step): # exp decay
|
||||
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
|
||||
policy.set_eps(eps)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
update_per_step=args.update_per_step,
|
||||
# stop_fn=stop_fn,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
# assert stop_fn(result["best_reward"])
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bdq(get_args())
|
examples/box2d/results/bdq/BipedalWalker.png (new binary file, 87 KiB, not shown)
test/discrete/test_bdq.py (new file, 150 lines)
@@ -0,0 +1,150 @@
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import ContinuousToDiscrete, DummyVectorEnv
|
||||
from tianshou.policy import BranchingDQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils.net.common import BranchingNet
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# task
|
||||
parser.add_argument("--task", type=str, default="Pendulum-v1")
|
||||
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||
# network architecture
|
||||
parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64])
|
||||
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64])
|
||||
parser.add_argument("--action-per-branch", type=int, default=40)
|
||||
# training hyperparameters
|
||||
parser.add_argument("--seed", type=int, default=1626)
|
||||
parser.add_argument("--eps-test", type=float, default=0.01)
|
||||
parser.add_argument("--eps-train", type=float, default=0.76)
|
||||
parser.add_argument("--eps-decay", type=float, default=1e-4)
|
||||
parser.add_argument("--buffer-size", type=int, default=20000)
|
||||
parser.add_argument("--lr", type=float, default=1e-3)
|
||||
parser.add_argument("--gamma", type=float, default=0.9)
|
||||
parser.add_argument("--target-update-freq", type=int, default=200)
|
||||
parser.add_argument("--epoch", type=int, default=10)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=80000)
|
||||
parser.add_argument("--step-per-collect", type=int, default=10)
|
||||
parser.add_argument("--update-per-step", type=float, default=0.1)
|
||||
parser.add_argument("--batch-size", type=int, default=128)
|
||||
parser.add_argument("--training-num", type=int, default=10)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_bdq(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
env = ContinuousToDiscrete(env, args.action_per_branch)
|
||||
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.num_branches = env.action_space.shape[0]
|
||||
|
||||
if args.reward_threshold is None:
|
||||
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
|
||||
args.reward_threshold = default_reward_threshold.get(
|
||||
args.task, env.spec.reward_threshold
|
||||
)
|
||||
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Num branches:", args.num_branches)
|
||||
print("Actions per branch:", args.action_per_branch)
|
||||
|
||||
train_envs = DummyVectorEnv(
|
||||
[
|
||||
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
for _ in range(args.training_num)
|
||||
]
|
||||
)
|
||||
test_envs = DummyVectorEnv(
|
||||
[
|
||||
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
|
||||
for _ in range(args.test_num)
|
||||
]
|
||||
)
|
||||
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = BranchingNet(
|
||||
args.state_shape,
|
||||
args.num_branches,
|
||||
args.action_per_branch,
|
||||
args.common_hidden_sizes,
|
||||
args.value_hidden_sizes,
|
||||
args.action_hidden_sizes,
|
||||
device=args.device,
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = BranchingDQNPolicy(
|
||||
net, optim, args.gamma, target_update_freq=args.target_update_freq
|
||||
)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, args.training_num),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=False)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
|
||||
def train_fn(epoch, env_step): # exp decay
|
||||
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
|
||||
policy.set_eps(eps)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
update_per_step=args.update_per_step,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
)
|
||||
|
||||
# assert stop_fn(result["best_reward"])
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bdq(get_args())
|
@@ -1,7 +1,7 @@
import pickle
from copy import deepcopy
from numbers import Number
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, no_type_check

import h5py
import numpy as np
@@ -10,6 +10,7 @@ import torch
from tianshou.data.batch import Batch, _parse_value


@no_type_check
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
    """Return an object without torch.Tensor."""
    if isinstance(x, torch.Tensor):  # most often case
@@ -30,6 +31,7 @@ def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
    return np.asanyarray(x)


@no_type_check
def to_torch(
    x: Any,
    dtype: Optional[torch.dtype] = None,
@@ -39,14 +41,14 @@ def to_torch(
    if isinstance(x, np.ndarray) and issubclass(
        x.dtype.type, (np.bool_, np.number)
    ):  # most often case
        x = torch.from_numpy(x).to(device)  # type: ignore
        x = torch.from_numpy(x).to(device)
        if dtype is not None:
            x = x.type(dtype)
        return x
    elif isinstance(x, torch.Tensor):  # second often case
        if dtype is not None:
            x = x.type(dtype)
        return x.to(device)  # type: ignore
        return x.to(device)
    elif isinstance(x, (np.number, np.bool_, Number)):
        return to_torch(np.asanyarray(x), dtype, device)
    elif isinstance(x, (dict, Batch)):
@@ -59,6 +61,7 @@ def to_torch(
    raise TypeError(f"object {x} cannot be converted to torch.")


@no_type_check
def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
    """Return an object without np.ndarray.
tianshou/env/__init__.py (2 changes)
@@ -1,5 +1,6 @@
"""Env package."""

from tianshou.env.gym_wrappers import ContinuousToDiscrete
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
from tianshou.env.venvs import (
    BaseVectorEnv,
@@ -23,4 +24,5 @@ __all__ = [
    "VectorEnvWrapper",
    "VectorEnvNormObs",
    "PettingZooEnv",
    "ContinuousToDiscrete",
]
tianshou/env/gym_wrappers.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import gym
import numpy as np


class ContinuousToDiscrete(gym.ActionWrapper):
    """Gym environment wrapper to take discrete action in a continuous environment.

    :param gym.Env env: gym environment with continuous action space.
    :param int action_per_branch: number of discrete actions in each dimension
        of the action space.
    """

    def __init__(self, env: gym.Env, action_per_branch: int) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.Box)
        low, high = env.action_space.low, env.action_space.high
        num_branches = env.action_space.shape[0]
        self.action_space = gym.spaces.MultiDiscrete(
            [action_per_branch] * num_branches
        )
        mesh = []
        for lo, hi in zip(low, high):
            mesh.append(np.linspace(lo, hi, action_per_branch))
        self.mesh = np.array(mesh)

    def action(self, act: np.ndarray) -> np.ndarray:
        # map each branch's discrete index back to its continuous value
        return np.array([self.mesh[i][a] for i, a in enumerate(act)])
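A minimal usage sketch for the wrapper above (illustrative only, not part of the diff; Pendulum-v1 has a single torque dimension in [-2, 2]):

import gym
import numpy as np

from tianshou.env import ContinuousToDiscrete

env = ContinuousToDiscrete(gym.make("Pendulum-v1"), action_per_branch=5)
print(env.action_space)  # MultiDiscrete([5])
# Each index selects a point on an evenly spaced grid per dimension:
# 0 -> -2.0, 1 -> -1.0, 2 -> 0.0, 3 -> 1.0, 4 -> 2.0
obs = env.reset()
obs, rew, done, info = env.step(np.array([4]))  # applies torque +2.0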
@@ -4,6 +4,7 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.bdq import BranchingDQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.rainbow import RainbowPolicy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
@@ -34,6 +35,7 @@ __all__ = [
    "BasePolicy",
    "RandomPolicy",
    "DQNPolicy",
    "BranchingDQNPolicy",
    "C51Policy",
    "RainbowPolicy",
    "QRDQNPolicy",
@@ -339,7 +339,7 @@ class BasePolicy(ABC, nn.Module):
            assert np.isclose(gae_lambda, 1.0)
            v_s_ = np.zeros_like(rew)
        else:
            v_s_ = to_numpy(v_s_.flatten())  # type: ignore
            v_s_ = to_numpy(v_s_.flatten())
        v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices)
        v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
@@ -54,11 +54,11 @@ class ImitationPolicy(BasePolicy):
        if self.action_type == "continuous":  # regression
            act = self(batch).act
            act_target = to_torch(batch.act, dtype=torch.float32, device=act.device)
            loss = F.mse_loss(act, act_target)  # type: ignore
            loss = F.mse_loss(act, act_target)
        elif self.action_type == "discrete":  # classification
            act = F.log_softmax(self(batch).logits, dim=-1)
            act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
            loss = F.nll_loss(act, act_target)  # type: ignore
            loss = F.nll_loss(act, act_target)
        loss.backward()
        self.optim.step()
        return {"loss": loss.item()}
@@ -104,9 +104,7 @@ class BCQPolicy(BasePolicy):
        """Compute action over the given batch data."""
        # There is "obs" in the Batch
        # obs_group: several groups. Each group has a state.
        obs_group: torch.Tensor = to_torch(  # type: ignore
            batch.obs, device=self.device
        )
        obs_group: torch.Tensor = to_torch(batch.obs, device=self.device)
        act_group = []
        for obs in obs_group:
            # now obs is (state_dim)
@@ -132,9 +130,7 @@ class BCQPolicy(BasePolicy):
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # batch: obs, act, rew, done, obs_next. (numpy array)
        # (batch_size, state_dim)
        batch: Batch = to_torch(  # type: ignore
            batch, dtype=torch.float, device=self.device
        )
        batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
        obs, act = batch.obs, batch.act
        batch_size = obs.shape[0]
@@ -160,9 +160,7 @@ class CQLPolicy(SACPolicy):
        return batch

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        batch: Batch = to_torch(  # type: ignore
            batch, dtype=torch.float, device=self.device,
        )
        batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
        obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
        batch_size = obs.shape[0]

@@ -208,12 +206,12 @@ class CQLPolicy(SACPolicy):
        random_actions = torch.FloatTensor(
            batch_size * self.num_repeat_actions, act.shape[-1]
        ).uniform_(-self.min_action, self.max_action).to(self.device)
        tmp_obs = obs.unsqueeze(1) \
            .repeat(1, self.num_repeat_actions, 1) \
            .view(batch_size * self.num_repeat_actions, obs.shape[-1])
        tmp_obs_next = obs_next.unsqueeze(1) \
            .repeat(1, self.num_repeat_actions, 1) \
            .view(batch_size * self.num_repeat_actions, obs.shape[-1])

        obs_len = len(obs.shape)
        repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1)
        view_size = [batch_size * self.num_repeat_actions] + list(obs.shape[1:])
        tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size)
        tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size)
        # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)

        current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
@@ -113,10 +113,7 @@ class DiscreteBCQPolicy(DQNPolicy):
        current_q = result.q_value[np.arange(len(target_q)), batch.act]
        act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
        q_loss = F.smooth_l1_loss(current_q, target_q)
        i_loss = F.nll_loss(
            F.log_softmax(imitation_logits, dim=-1),
            act  # type: ignore
        )
        i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act)
        reg_loss = imitation_logits.pow(2).mean()
        loss = q_loss + i_loss + self._weight_reg * reg_loss
@@ -108,9 +108,9 @@ class GAILPolicy(PPOPolicy):
        return super().process_fn(batch, buffer, indices)

    def disc(self, batch: Batch) -> torch.Tensor:
        obs = to_torch(batch.obs, device=self.disc_net.device)  # type: ignore
        act = to_torch(batch.act, device=self.disc_net.device)  # type: ignore
        return self.disc_net(torch.cat([obs, act], dim=1))  # type: ignore
        obs = to_torch(batch.obs, device=self.disc_net.device)
        act = to_torch(batch.act, device=self.disc_net.device)
        return self.disc_net(torch.cat([obs, act], dim=1))

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -105,7 +105,7 @@ class ICMPolicy(BasePolicy):
        self.optim.zero_grad()
        act_hat = batch.policy.act_hat
        act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
        inverse_loss = F.cross_entropy(act_hat, act).mean()  # type: ignore
        inverse_loss = F.cross_entropy(act_hat, act).mean()
        forward_loss = batch.policy.mse_loss.mean()
        loss = (
            (1 - self.forward_loss_weight) * inverse_loss +
tianshou/policy/modelfree/bdq.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import BranchingNet


class BranchingDQNPolicy(DQNPolicy):
    """Implementation of the Branching Dueling Q-Network (BDQ), arXiv:1711.08946.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network). Default to 0.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param bool is_double: use double network. Default to True.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: BranchingNet,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        is_double: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            model, optim, discount_factor, estimation_step, target_update_freq,
            reward_normalization, is_double
        )
        assert estimation_step == 1, "N-step bigger than one is not supported by BDQ"
        self.max_action_num = model.action_per_branch
        self.num_branches = model.num_branches

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        result = self(batch, input="obs_next")
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            target_q = self(batch, model="model_old", input="obs_next").logits
        else:
            target_q = result.logits
        if self._is_double:
            act = np.expand_dims(self(batch, input="obs_next").act, -1)
            act = to_torch(act, dtype=torch.long, device=target_q.device)
        else:
            act = target_q.max(-1).indices.unsqueeze(-1)
        return torch.gather(target_q, -1, act).squeeze()

    def _compute_return(
        self,
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        gamma: float = 0.99,
    ) -> Batch:
        rew = batch.rew
        with torch.no_grad():
            target_q_torch = self._target_q(buffer, indice)  # (bsz, ?)
        target_q = to_numpy(target_q_torch)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        end_flag = end_flag[indice]
        mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q
        _target_q = rew + gamma * mean_target_q * (1 - end_flag)
        target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1)
        target_q = np.repeat(target_q[..., None], self.max_action_num, axis=-1)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Compute the 1-step return for BDQ targets."""
        return self._compute_return(batch, buffer, indices)

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[Dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        model = getattr(self, model)
        obs = batch[input]
        obs_next = obs.obs if hasattr(obs, "obs") else obs
        logits, hidden = model(obs_next, state=state, info=batch.info)
        act = to_numpy(logits.max(dim=-1)[1])
        return Batch(logits=logits, act=act, state=hidden)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
        q = self(batch).logits
        act_mask = torch.zeros_like(q)
        act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
        act_q = q * act_mask
        returns = batch.returns
        returns = returns * act_mask
        td_error = returns - act_q
        loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
        batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(
        self,
        act: Union[np.ndarray, Batch],
        batch: Batch,
    ) -> Union[np.ndarray, Batch]:
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
            rand_act = np.random.randint(
                low=0, high=self.max_action_num, size=(bsz, act.shape[-1])
            )
            if hasattr(batch.obs, "mask"):
                rand_act += batch.obs.mask
            act[rand_mask] = rand_act[rand_mask]
        return act
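For intuition about the masked TD loss in `learn` above, here is a small shape sketch with made-up sizes (illustrative only, not part of the diff):

import torch

bsz, num_branches, action_per_branch = 4, 3, 5
q = torch.randn(bsz, num_branches, action_per_branch)        # per-branch Q-values
returns = torch.randn(bsz, num_branches, action_per_branch)  # broadcast 1-step target
act = torch.randint(0, action_per_branch, (bsz, num_branches))

# One-hot mask selecting the chosen sub-action in every branch.
act_mask = torch.zeros_like(q).scatter_(-1, act.unsqueeze(-1), 1)
td_error = (returns - q) * act_mask             # only the chosen bins contribute
loss = td_error.pow(2).sum(-1).mean(-1).mean()  # sum over bins, mean over branches/batch
print(q.shape, loss.item())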
@@ -138,8 +138,6 @@ class DQNPolicy(BasePolicy):
                ...
            )

        :param float eps: in [0, 1], for epsilon-greedy exploration method.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
@@ -177,8 +175,11 @@ class DQNPolicy(BasePolicy):
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
    def exploration_noise(
        self,
        act: Union[np.ndarray, Batch],
        batch: Batch,
    ) -> Union[np.ndarray, Batch]:
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
@@ -363,3 +363,93 @@ class EnsembleLinear(nn.Module):
        if self.bias is not None:
            x = x + self.bias
        return x


class BranchingNet(nn.Module):
    """Branching Dueling Q-Network (BDQ).

    Network for BranchingDQNPolicy. It uses a common network module, a value
    module, and one action "branch" module per action dimension, so the number
    of Q-value outputs scales linearly with the number of dimensions in the
    action space. For more info please refer to: arXiv:1711.08946.
    :param state_shape: int or a sequence of int of the shape of state.
    :param num_branches: number of action dimensions, i.e. the number of
        action branches.
    :param action_per_branch: number of discrete actions in each dimension.
    :param common_hidden_sizes: shape of the common MLP network passed in as a list.
    :param value_hidden_sizes: shape of the value MLP network passed in as a list.
    :param action_hidden_sizes: shape of the action MLP network passed in as a list.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
        You can also pass a list of normalization modules with the same length
        of hidden_sizes, to use different normalization module in different
        layers. Default to no normalization.
    :param activation: which activation to use after each layer, can be both
        the same activation for all layers if passed in nn.Module, or different
        activation for different Modules if passed in a list. Default to
        nn.ReLU.
    :param device: specify the device when the network actually runs. Default
        to "cpu".
    """

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        num_branches: int = 0,
        action_per_branch: int = 2,
        common_hidden_sizes: List[int] = [],
        value_hidden_sizes: List[int] = [],
        action_hidden_sizes: List[int] = [],
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.num_branches = num_branches
        self.action_per_branch = action_per_branch
        # common network
        common_input_dim = int(np.prod(state_shape))
        common_output_dim = 0
        self.common = MLP(
            common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
            activation, device
        )
        # value network
        value_input_dim = common_hidden_sizes[-1]
        value_output_dim = 1
        self.value = MLP(
            value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
            activation, device
        )
        # action branching network
        action_input_dim = common_hidden_sizes[-1]
        action_output_dim = action_per_branch
        self.branches = nn.ModuleList(
            [
                MLP(
                    action_input_dim, action_output_dim, action_hidden_sizes,
                    norm_layer, activation, device
                ) for _ in range(self.num_branches)
            ]
        )

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs -> model -> logits."""
        common_out = self.common(obs)
        value_out = self.value(common_out)
        value_out = torch.unsqueeze(value_out, 1)
        action_out = []
        for b in self.branches:
            action_out.append(b(common_out))
        action_scores = torch.stack(action_out, 1)
        # dueling aggregation: subtract each branch's mean advantage
        action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
        logits = value_out + action_scores
        return logits, state
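A quick shape check for BranchingNet (a sketch, not part of the diff; sizes are illustrative, with 24 being BipedalWalker's observation dimension):

import numpy as np

from tianshou.utils.net.common import BranchingNet

net = BranchingNet(
    state_shape=24,
    num_branches=4,
    action_per_branch=25,
    common_hidden_sizes=[512, 256],
    value_hidden_sizes=[128],
    action_hidden_sizes=[128],
)
obs = np.zeros((7, 24), dtype=np.float32)  # batch of 7 observations
logits, state = net(obs)
print(logits.shape)  # torch.Size([7, 4, 25]) -- one row of Q-values per branch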