Add Rainbow DQN (#386)

- add RainbowPolicy
- add `set_beta` method in prio_buffer
- add NoisyLinear in utils/network
This commit is contained in:
Yi Su 2021-08-29 08:34:59 -07:00 committed by GitHub
parent d161059c3d
commit 291be08d43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 636 additions and 9 deletions

View File

@ -22,6 +22,7 @@
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)

View File

@ -30,6 +30,11 @@ DQN Family
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.RainbowPolicy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.QRDQNPolicy
:members:
:undoc-members:

View File

@ -13,6 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1707.02298.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_

View File

@ -82,6 +82,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
# Rainbow (single run)
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
# BCQ
To running BCQ algorithm on Atari, you need to do the following things:

View File

@ -2,6 +2,7 @@ import torch
import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.utils.net.discrete import NoisyLinear
class DQN(nn.Module):
@ -81,6 +82,65 @@ class C51(DQN):
return x, state
class Rainbow(DQN):
"""Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
num_atoms: int = 51,
noisy_std: float = 0.5,
device: Union[str, int, torch.device] = "cpu",
is_dueling: bool = True,
is_noisy: bool = True,
) -> None:
super().__init__(c, h, w, action_shape, device, features_only=True)
self.action_num = np.prod(action_shape)
self.num_atoms = num_atoms
def linear(x, y):
if is_noisy:
return NoisyLinear(x, y, noisy_std)
else:
return nn.Linear(x, y)
self.Q = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
linear(512, self.action_num * self.num_atoms))
self._is_dueling = is_dueling
if self._is_dueling:
self.V = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
linear(512, self.num_atoms))
self.output_dim = self.action_num * self.num_atoms
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
q = self.Q(x)
q = q.view(-1, self.action_num, self.num_atoms)
if self._is_dueling:
v = self.V(x)
v = v.view(-1, 1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
else:
logits = q
y = logits.softmax(dim=2)
return y, state
class QRDQN(DQN):
"""Reference: Distributional Reinforcement Learning with Quantile \
Regression.

View File

@ -0,0 +1,204 @@
import os
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from atari_network import Rainbow
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0000625)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--noisy-std', type=float, default=0.1)
parser.add_argument('--no-dueling', action='store_true', default=False)
parser.add_argument('--no-noisy', action='store_true', default=False)
parser.add_argument('--no-priority', action='store_true', default=False)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.)
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
parser.add_argument('--no-weight-norm', action='store_true', default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_rainbow(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = Rainbow(*args.state_shape, args.action_shape,
args.num_atoms, args.noisy_std, args.device,
is_dueling=not args.no_dueling,
is_noisy=not args.no_noisy)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = RainbowPolicy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
if args.no_priority:
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
else:
buffer = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
beta=args.beta, weight_norm=not args.no_weight_norm)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'rainbow',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
logger.write('train/eps', env_step, eps)
if not args.no_priority:
if env_step <= args.beta_anneal_step:
beta = args.beta - env_step / args.beta_anneal_step * \
(args.beta - args.beta_final)
else:
beta = args.beta_final
buffer.set_beta(beta)
logger.write('train/beta', env_step, beta)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack, alpha=args.alpha,
beta=args.beta)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()
exit(0)
# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)
watch()
if __name__ == '__main__':
test_rainbow(get_args())

Binary file not shown.

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 192 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 233 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 224 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 198 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 226 KiB

View File

@ -193,7 +193,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
mask = np.isin(np.arange(buf2.maxsize), indices)
assert np.all(weight[mask] == weight[mask][0])
assert np.all(weight[~mask] == weight[~mask][0])
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
def test_update():

View File

@ -54,6 +54,8 @@ def get_args():
def test_qrdqn(args=get_args()):
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)

View File

@ -50,7 +50,7 @@ def test_discrete_cql(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
env.spec.reward_threshold = 185 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(

View File

@ -0,0 +1,198 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import NoisyLinear
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--noisy-std', type=float, default=0.1)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=8000)
parser.add_argument('--step-per-collect', type=int, default=8)
parser.add_argument('--update-per-step', type=float, default=0.125)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay',
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
def test_rainbow(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
def noisy_linear(x, y):
return NoisyLinear(x, y, args.noisy_std)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=True, num_atoms=args.num_atoms,
dueling_param=({"linear_layer": noisy_linear},
{"linear_layer": noisy_linear}))
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = RainbowPolicy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs),
alpha=args.alpha, beta=args.beta, weight_norm=True)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'rainbow')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step):
# eps annealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)
# beta annealing, just a demo
if args.prioritized_replay:
if env_step <= 10000:
beta = args.beta
elif env_step <= 50000:
beta = args.beta - (env_step - 10000) / \
40000 * (args.beta - args.beta_final)
else:
beta = args.beta_final
buf.set_beta(beta)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_rainbow_resume(args=get_args()):
args.resume = True
test_rainbow(args)
def test_prainbow(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
args.seed = 1
test_rainbow(args)
if __name__ == '__main__':
test_rainbow(get_args())

View File

@ -10,13 +10,22 @@ class PrioritizedReplayBuffer(ReplayBuffer):
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
:param bool weight_norm: whether to normalize returned weights with the maximum
weight value within the batch. Default to True.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
def __init__(
self,
size: int,
alpha: float,
beta: float,
weight_norm: bool = True,
**kwargs: Any
) -> None:
# will raise KeyError in PrioritizedVectorReplayBuffer
# super().__init__(size, **kwargs)
ReplayBuffer.__init__(self, size, **kwargs)
@ -27,6 +36,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self.weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
self.options.update(alpha=alpha, beta=beta)
self._weight_norm = weight_norm
def init_weight(self, index: Union[int, np.ndarray]) -> None:
self.weight[index] = self._max_prio ** self._alpha
@ -83,5 +93,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
else:
indices = index
batch = super().__getitem__(indices)
batch.weight = self.get_weight(indices)
weight = self.get_weight(indices)
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
batch.weight = weight / np.max(weight) if self._weight_norm else weight
return batch
def set_beta(self, beta: float) -> None:
self._beta = beta

View File

@ -55,3 +55,7 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)
]
super().__init__(buffer_list)
def set_beta(self, beta: float) -> None:
for buffer in self.buffers:
buffer.set_beta(beta)

View File

@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy
from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.rainbow import RainbowPolicy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.iqn import IQNPolicy
from tianshou.policy.modelfree.fqf import FQFPolicy
@ -27,6 +28,7 @@ __all__ = [
"RandomPolicy",
"DQNPolicy",
"C51Policy",
"RainbowPolicy",
"QRDQNPolicy",
"IQNPolicy",
"FQFPolicy",

View File

@ -0,0 +1,37 @@
from typing import Any, Dict
from tianshou.policy import C51Policy
from tianshou.data import Batch
from tianshou.utils.net.discrete import sample_noise
class RainbowPolicy(C51Policy):
"""Implementation of Rainbow DQN. arXiv:1710.02298.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int num_atoms: the number of atoms in the support set of the
value distribution. Default to 51.
:param float v_min: the value of the smallest atom in the support set.
Default to -10.0.
:param float v_max: the value of the largest atom in the support set.
Default to 10.0.
:param int estimation_step: the number of steps to look ahead. Default to 1.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
.. seealso::
Please refer to :class:`~tianshou.policy.C51Policy` for more detailed
explanation.
"""
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
sample_noise(self.model)
if self._target and sample_noise(self.model_old):
self.model_old.train() # so that NoisyLinear takes effect
return super().learn(batch, **kwargs)

View File

@ -11,10 +11,11 @@ def miniblock(
output_size: int = 0,
norm_layer: Optional[ModuleType] = None,
activation: Optional[ModuleType] = None,
linear_layer: Type[nn.Linear] = nn.Linear,
) -> List[nn.Module]:
"""Construct a miniblock with given input/output-size, norm layer and \
activation."""
layers: List[nn.Module] = [nn.Linear(input_size, output_size)]
layers: List[nn.Module] = [linear_layer(input_size, output_size)]
if norm_layer is not None:
layers += [norm_layer(output_size)] # type: ignore
if activation is not None:
@ -42,6 +43,8 @@ class MLP(nn.Module):
the same actvition for all layers if passed in nn.Module, or different
activation for different Modules if passed in a list. Default to
nn.ReLU.
:param device: which device to create this model on. Default to None.
:param linear_layer: use this module as linear layer. Default to nn.Linear.
"""
def __init__(
@ -52,6 +55,7 @@ class MLP(nn.Module):
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
device: Optional[Union[str, int, torch.device]] = None,
linear_layer: Type[nn.Linear] = nn.Linear,
) -> None:
super().__init__()
self.device = device
@ -78,9 +82,9 @@ class MLP(nn.Module):
for in_dim, out_dim, norm, activ in zip(
hidden_sizes[:-1], hidden_sizes[1:],
norm_layer_list, activation_list):
model += miniblock(in_dim, out_dim, norm, activ)
model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
if output_dim > 0:
model += [nn.Linear(hidden_sizes[-1], output_dim)]
model += [linear_layer(hidden_sizes[-1], output_dim)]
self.output_dim = output_dim or hidden_sizes[-1]
self.model = nn.Sequential(*model)
@ -168,10 +172,10 @@ class Net(nn.Module):
q_output_dim, v_output_dim = action_dim, num_atoms
q_kwargs: Dict[str, Any] = {
**q_kwargs, "input_dim": self.output_dim,
"output_dim": q_output_dim}
"output_dim": q_output_dim, "device": self.device}
v_kwargs: Dict[str, Any] = {
**v_kwargs, "input_dim": self.output_dim,
"output_dim": v_output_dim}
"output_dim": v_output_dim, "device": self.device}
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
self.output_dim = self.Q.output_dim

View File

@ -307,3 +307,83 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
with torch.no_grad():
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
return (quantiles, fractions, quantiles_tau), h
class NoisyLinear(nn.Module):
"""Implementation of Noisy Networks. arXiv:1706.10295.
:param int in_features: the number of input features.
:param int out_features: the number of output features.
:param float noisy_std: initial standard deviation of noisy linear layers.
.. note::
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
/fqf_iqn_qrdqn/network.py .
"""
def __init__(
self, in_features: int, out_features: int, noisy_std: float = 0.5
) -> None:
super().__init__()
# Learnable parameters.
self.mu_W = nn.Parameter(
torch.FloatTensor(out_features, in_features))
self.sigma_W = nn.Parameter(
torch.FloatTensor(out_features, in_features))
self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))
# Factorized noise parameters.
self.register_buffer('eps_p', torch.FloatTensor(in_features))
self.register_buffer('eps_q', torch.FloatTensor(out_features))
self.in_features = in_features
self.out_features = out_features
self.sigma = noisy_std
self.reset()
self.sample()
def reset(self) -> None:
bound = 1 / np.sqrt(self.in_features)
self.mu_W.data.uniform_(-bound, bound)
self.mu_bias.data.uniform_(-bound, bound)
self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features))
self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features))
def f(self, x: torch.Tensor) -> torch.Tensor:
x = torch.randn(x.size(0), device=x.device)
return x.sign().mul_(x.abs().sqrt_())
def sample(self) -> None:
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
weight = self.mu_W + self.sigma_W * (
self.eps_q.ger(self.eps_p) # type: ignore
)
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
else:
weight = self.mu_W
bias = self.mu_bias
return F.linear(x, weight, bias)
def sample_noise(model: nn.Module) -> bool:
"""Sample the random noises of NoisyLinear modules in the model.
:param model: a PyTorch module which may have NoisyLinear submodules.
:returns: True if model has at least one NoisyLinear submodule;
otherwise, False.
"""
done = False
for m in model.modules():
if isinstance(m, NoisyLinear):
m.sample()
done = True
return done