Add Rainbow DQN (#386)
- add RainbowPolicy - add `set_beta` method in prio_buffer - add NoisyLinear in utils/network
@ -22,6 +22,7 @@
|
|||||||
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
|
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
|
||||||
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
|
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
|
||||||
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
||||||
|
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
|
||||||
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
||||||
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
||||||
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
|
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
|
||||||
|
@ -30,6 +30,11 @@ DQN Family
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: tianshou.policy.RainbowPolicy
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autoclass:: tianshou.policy.QRDQNPolicy
|
.. autoclass:: tianshou.policy.QRDQNPolicy
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
|
@ -13,6 +13,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
|
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
|
||||||
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
|
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
|
||||||
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
||||||
|
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1707.02298.pdf>`_
|
||||||
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
||||||
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
||||||
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
|
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
|
||||||
|
@ -82,6 +82,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
|||||||
| SeaquestNoFrameskip-v4 | 10775 |  | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
|
| SeaquestNoFrameskip-v4 | 10775 |  | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
|
||||||
| SpaceInvadersNoFrameskip-v4 | 2482 |  | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
|
| SpaceInvadersNoFrameskip-v4 | 2482 |  | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||||
|
|
||||||
|
# Rainbow (single run)
|
||||||
|
|
||||||
|
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||||
|
|
||||||
|
| task | best reward | reward curve | parameters |
|
||||||
|
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||||
|
| PongNoFrameskip-v4 | 21 |  | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` |
|
||||||
|
| BreakoutNoFrameskip-v4 | 684.6 |  | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
|
||||||
|
| EnduroNoFrameskip-v4 | 1625.9 |  | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` |
|
||||||
|
| QbertNoFrameskip-v4 | 16192.5 |  | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` |
|
||||||
|
| MsPacmanNoFrameskip-v4 | 3101 |  | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
|
||||||
|
| SeaquestNoFrameskip-v4 | 2126 |  | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
|
||||||
|
| SpaceInvadersNoFrameskip-v4 | 1794.5 |  | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||||
|
|
||||||
# BCQ
|
# BCQ
|
||||||
|
|
||||||
To running BCQ algorithm on Atari, you need to do the following things:
|
To running BCQ algorithm on Atari, you need to do the following things:
|
||||||
|
@ -2,6 +2,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||||
|
from tianshou.utils.net.discrete import NoisyLinear
|
||||||
|
|
||||||
|
|
||||||
class DQN(nn.Module):
|
class DQN(nn.Module):
|
||||||
@ -81,6 +82,65 @@ class C51(DQN):
|
|||||||
return x, state
|
return x, state
|
||||||
|
|
||||||
|
|
||||||
|
class Rainbow(DQN):
|
||||||
|
"""Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
|
||||||
|
|
||||||
|
For advanced usage (how to customize the network), please refer to
|
||||||
|
:ref:`build_the_network`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
c: int,
|
||||||
|
h: int,
|
||||||
|
w: int,
|
||||||
|
action_shape: Sequence[int],
|
||||||
|
num_atoms: int = 51,
|
||||||
|
noisy_std: float = 0.5,
|
||||||
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
|
is_dueling: bool = True,
|
||||||
|
is_noisy: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(c, h, w, action_shape, device, features_only=True)
|
||||||
|
self.action_num = np.prod(action_shape)
|
||||||
|
self.num_atoms = num_atoms
|
||||||
|
|
||||||
|
def linear(x, y):
|
||||||
|
if is_noisy:
|
||||||
|
return NoisyLinear(x, y, noisy_std)
|
||||||
|
else:
|
||||||
|
return nn.Linear(x, y)
|
||||||
|
|
||||||
|
self.Q = nn.Sequential(
|
||||||
|
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
||||||
|
linear(512, self.action_num * self.num_atoms))
|
||||||
|
self._is_dueling = is_dueling
|
||||||
|
if self._is_dueling:
|
||||||
|
self.V = nn.Sequential(
|
||||||
|
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
||||||
|
linear(512, self.num_atoms))
|
||||||
|
self.output_dim = self.action_num * self.num_atoms
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Union[np.ndarray, torch.Tensor],
|
||||||
|
state: Optional[Any] = None,
|
||||||
|
info: Dict[str, Any] = {},
|
||||||
|
) -> Tuple[torch.Tensor, Any]:
|
||||||
|
r"""Mapping: x -> Z(x, \*)."""
|
||||||
|
x, state = super().forward(x)
|
||||||
|
q = self.Q(x)
|
||||||
|
q = q.view(-1, self.action_num, self.num_atoms)
|
||||||
|
if self._is_dueling:
|
||||||
|
v = self.V(x)
|
||||||
|
v = v.view(-1, 1, self.num_atoms)
|
||||||
|
logits = q - q.mean(dim=1, keepdim=True) + v
|
||||||
|
else:
|
||||||
|
logits = q
|
||||||
|
y = logits.softmax(dim=2)
|
||||||
|
return y, state
|
||||||
|
|
||||||
|
|
||||||
class QRDQN(DQN):
|
class QRDQN(DQN):
|
||||||
"""Reference: Distributional Reinforcement Learning with Quantile \
|
"""Reference: Distributional Reinforcement Learning with Quantile \
|
||||||
Regression.
|
Regression.
|
||||||
|
204
examples/atari/atari_rainbow.py
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import pprint
|
||||||
|
import datetime
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.policy import RainbowPolicy
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
|
from tianshou.env import SubprocVectorEnv
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
||||||
|
|
||||||
|
from atari_network import Rainbow
|
||||||
|
from atari_wrapper import wrap_deepmind
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||||
|
parser.add_argument('--eps-train', type=float, default=1.)
|
||||||
|
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||||
|
parser.add_argument('--lr', type=float, default=0.0000625)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
|
parser.add_argument('--num-atoms', type=int, default=51)
|
||||||
|
parser.add_argument('--v-min', type=float, default=-10.)
|
||||||
|
parser.add_argument('--v-max', type=float, default=10.)
|
||||||
|
parser.add_argument('--noisy-std', type=float, default=0.1)
|
||||||
|
parser.add_argument('--no-dueling', action='store_true', default=False)
|
||||||
|
parser.add_argument('--no-noisy', action='store_true', default=False)
|
||||||
|
parser.add_argument('--no-priority', action='store_true', default=False)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.5)
|
||||||
|
parser.add_argument('--beta', type=float, default=0.4)
|
||||||
|
parser.add_argument('--beta-final', type=float, default=1.)
|
||||||
|
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
|
||||||
|
parser.add_argument('--no-weight-norm', action='store_true', default=False)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||||
|
parser.add_argument('--epoch', type=int, default=100)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=32)
|
||||||
|
parser.add_argument('--training-num', type=int, default=10)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str,
|
||||||
|
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
parser.add_argument('--frames-stack', type=int, default=4)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument('--watch', default=False, action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only')
|
||||||
|
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def make_atari_env(args):
|
||||||
|
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||||
|
|
||||||
|
|
||||||
|
def make_atari_env_watch(args):
|
||||||
|
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||||
|
episode_life=False, clip_rewards=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rainbow(args=get_args()):
|
||||||
|
env = make_atari_env(args)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
# should be N_FRAMES x H x W
|
||||||
|
print("Observations shape:", args.state_shape)
|
||||||
|
print("Actions shape:", args.action_shape)
|
||||||
|
# make environments
|
||||||
|
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
|
||||||
|
for _ in range(args.training_num)])
|
||||||
|
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||||
|
for _ in range(args.test_num)])
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# define model
|
||||||
|
net = Rainbow(*args.state_shape, args.action_shape,
|
||||||
|
args.num_atoms, args.noisy_std, args.device,
|
||||||
|
is_dueling=not args.no_dueling,
|
||||||
|
is_noisy=not args.no_noisy)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
|
# define policy
|
||||||
|
policy = RainbowPolicy(
|
||||||
|
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
|
||||||
|
args.n_step, target_update_freq=args.target_update_freq
|
||||||
|
).to(args.device)
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||||
|
# when you have enough RAM
|
||||||
|
if args.no_priority:
|
||||||
|
buffer = VectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
|
||||||
|
save_only_last_obs=True, stack_num=args.frames_stack)
|
||||||
|
else:
|
||||||
|
buffer = PrioritizedVectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
|
||||||
|
save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
|
||||||
|
beta=args.beta, weight_norm=not args.no_weight_norm)
|
||||||
|
# collector
|
||||||
|
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(
|
||||||
|
args.logdir, args.task, 'rainbow',
|
||||||
|
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = BasicLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
if env.spec.reward_threshold:
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
elif 'Pong' in args.task:
|
||||||
|
return mean_rewards >= 20
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def train_fn(epoch, env_step):
|
||||||
|
# nature DQN setting, linear decay in the first 1M steps
|
||||||
|
if env_step <= 1e6:
|
||||||
|
eps = args.eps_train - env_step / 1e6 * \
|
||||||
|
(args.eps_train - args.eps_train_final)
|
||||||
|
else:
|
||||||
|
eps = args.eps_train_final
|
||||||
|
policy.set_eps(eps)
|
||||||
|
logger.write('train/eps', env_step, eps)
|
||||||
|
if not args.no_priority:
|
||||||
|
if env_step <= args.beta_anneal_step:
|
||||||
|
beta = args.beta - env_step / args.beta_anneal_step * \
|
||||||
|
(args.beta - args.beta_final)
|
||||||
|
else:
|
||||||
|
beta = args.beta_final
|
||||||
|
buffer.set_beta(beta)
|
||||||
|
logger.write('train/beta', env_step, beta)
|
||||||
|
|
||||||
|
def test_fn(epoch, env_step):
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
|
||||||
|
# watch agent's performance
|
||||||
|
def watch():
|
||||||
|
print("Setup test envs ...")
|
||||||
|
policy.eval()
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
if args.save_buffer_name:
|
||||||
|
print(f"Generate buffer with size {args.buffer_size}")
|
||||||
|
buffer = PrioritizedVectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(test_envs),
|
||||||
|
ignore_obs_next=True, save_only_last_obs=True,
|
||||||
|
stack_num=args.frames_stack, alpha=args.alpha,
|
||||||
|
beta=args.beta)
|
||||||
|
collector = Collector(policy, test_envs, buffer,
|
||||||
|
exploration_noise=True)
|
||||||
|
result = collector.collect(n_step=args.buffer_size)
|
||||||
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
|
else:
|
||||||
|
print("Testing agent ...")
|
||||||
|
test_collector.reset()
|
||||||
|
result = test_collector.collect(n_episode=args.test_num,
|
||||||
|
render=args.render)
|
||||||
|
rew = result["rews"].mean()
|
||||||
|
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||||
|
|
||||||
|
if args.watch:
|
||||||
|
watch()
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
# test train_collector and start filling replay buffer
|
||||||
|
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy, train_collector, test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
|
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||||
|
update_per_step=args.update_per_step, test_in_train=False)
|
||||||
|
|
||||||
|
pprint.pprint(result)
|
||||||
|
watch()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_rainbow(get_args())
|
BIN
examples/atari/results/rainbow/Breakout_rew.png
Normal file
After Width: | Height: | Size: 210 KiB |
BIN
examples/atari/results/rainbow/Enduro_rew.png
Normal file
After Width: | Height: | Size: 192 KiB |
BIN
examples/atari/results/rainbow/MsPacman_rew.png
Normal file
After Width: | Height: | Size: 233 KiB |
BIN
examples/atari/results/rainbow/Pong_rew.png
Normal file
After Width: | Height: | Size: 144 KiB |
BIN
examples/atari/results/rainbow/Qbert_rew.png
Normal file
After Width: | Height: | Size: 224 KiB |
BIN
examples/atari/results/rainbow/Seaquest_rew.png
Normal file
After Width: | Height: | Size: 198 KiB |
BIN
examples/atari/results/rainbow/SpaceInvaders_rew.png
Normal file
After Width: | Height: | Size: 226 KiB |
@ -193,7 +193,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
|||||||
mask = np.isin(np.arange(buf2.maxsize), indices)
|
mask = np.isin(np.arange(buf2.maxsize), indices)
|
||||||
assert np.all(weight[mask] == weight[mask][0])
|
assert np.all(weight[mask] == weight[mask][0])
|
||||||
assert np.all(weight[~mask] == weight[~mask][0])
|
assert np.all(weight[~mask] == weight[~mask][0])
|
||||||
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1
|
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
|
||||||
|
|
||||||
|
|
||||||
def test_update():
|
def test_update():
|
||||||
|
@ -54,6 +54,8 @@ def get_args():
|
|||||||
|
|
||||||
def test_qrdqn(args=get_args()):
|
def test_qrdqn(args=get_args()):
|
||||||
env = gym.make(args.task)
|
env = gym.make(args.task)
|
||||||
|
if args.task == 'CartPole-v0':
|
||||||
|
env.spec.reward_threshold = 190 # lower the goal
|
||||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
args.action_shape = env.action_space.shape or env.action_space.n
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
# train_envs = gym.make(args.task)
|
# train_envs = gym.make(args.task)
|
||||||
|
@ -50,7 +50,7 @@ def test_discrete_cql(args=get_args()):
|
|||||||
# envs
|
# envs
|
||||||
env = gym.make(args.task)
|
env = gym.make(args.task)
|
||||||
if args.task == 'CartPole-v0':
|
if args.task == 'CartPole-v0':
|
||||||
env.spec.reward_threshold = 190 # lower the goal
|
env.spec.reward_threshold = 185 # lower the goal
|
||||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
args.action_shape = env.action_space.shape or env.action_space.n
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
test_envs = DummyVectorEnv(
|
test_envs = DummyVectorEnv(
|
||||||
|
198
test/discrete/test_rainbow.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import os
|
||||||
|
import gym
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import pprint
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.policy import RainbowPolicy
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
|
from tianshou.env import DummyVectorEnv
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.net.discrete import NoisyLinear
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
|
parser.add_argument('--seed', type=int, default=1626)
|
||||||
|
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||||
|
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.9)
|
||||||
|
parser.add_argument('--num-atoms', type=int, default=51)
|
||||||
|
parser.add_argument('--v-min', type=float, default=-10.)
|
||||||
|
parser.add_argument('--v-max', type=float, default=10.)
|
||||||
|
parser.add_argument('--noisy-std', type=float, default=0.1)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=8000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=8)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.125)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
|
nargs='*', default=[128, 128, 128, 128])
|
||||||
|
parser.add_argument('--training-num', type=int, default=8)
|
||||||
|
parser.add_argument('--test-num', type=int, default=100)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument('--prioritized-replay',
|
||||||
|
action="store_true", default=False)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.6)
|
||||||
|
parser.add_argument('--beta', type=float, default=0.4)
|
||||||
|
parser.add_argument('--beta-final', type=float, default=1.)
|
||||||
|
parser.add_argument('--resume', action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str,
|
||||||
|
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
parser.add_argument("--save-interval", type=int, default=4)
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_rainbow(args=get_args()):
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
# you can also use tianshou.env.SubprocVectorEnv
|
||||||
|
train_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
|
||||||
|
def noisy_linear(x, y):
|
||||||
|
return NoisyLinear(x, y, args.noisy_std)
|
||||||
|
|
||||||
|
net = Net(args.state_shape, args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes, device=args.device,
|
||||||
|
softmax=True, num_atoms=args.num_atoms,
|
||||||
|
dueling_param=({"linear_layer": noisy_linear},
|
||||||
|
{"linear_layer": noisy_linear}))
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
|
policy = RainbowPolicy(
|
||||||
|
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
|
||||||
|
args.n_step, target_update_freq=args.target_update_freq
|
||||||
|
).to(args.device)
|
||||||
|
# buffer
|
||||||
|
if args.prioritized_replay:
|
||||||
|
buf = PrioritizedVectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(train_envs),
|
||||||
|
alpha=args.alpha, beta=args.beta, weight_norm=True)
|
||||||
|
else:
|
||||||
|
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||||
|
# collector
|
||||||
|
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||||
|
# policy.set_eps(1)
|
||||||
|
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'rainbow')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
logger = BasicLogger(writer, save_interval=args.save_interval)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
|
||||||
|
def train_fn(epoch, env_step):
|
||||||
|
# eps annealing, just a demo
|
||||||
|
if env_step <= 10000:
|
||||||
|
policy.set_eps(args.eps_train)
|
||||||
|
elif env_step <= 50000:
|
||||||
|
eps = args.eps_train - (env_step - 10000) / \
|
||||||
|
40000 * (0.9 * args.eps_train)
|
||||||
|
policy.set_eps(eps)
|
||||||
|
else:
|
||||||
|
policy.set_eps(0.1 * args.eps_train)
|
||||||
|
# beta annealing, just a demo
|
||||||
|
if args.prioritized_replay:
|
||||||
|
if env_step <= 10000:
|
||||||
|
beta = args.beta
|
||||||
|
elif env_step <= 50000:
|
||||||
|
beta = args.beta - (env_step - 10000) / \
|
||||||
|
40000 * (args.beta - args.beta_final)
|
||||||
|
else:
|
||||||
|
beta = args.beta_final
|
||||||
|
buf.set_beta(beta)
|
||||||
|
|
||||||
|
def test_fn(epoch, env_step):
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
|
||||||
|
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||||
|
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||||
|
torch.save({
|
||||||
|
'model': policy.state_dict(),
|
||||||
|
'optim': optim.state_dict(),
|
||||||
|
}, os.path.join(log_path, 'checkpoint.pth'))
|
||||||
|
pickle.dump(train_collector.buffer,
|
||||||
|
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
# load from existing checkpoint
|
||||||
|
print(f"Loading agent under {log_path}")
|
||||||
|
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||||
|
if os.path.exists(ckpt_path):
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||||
|
policy.load_state_dict(checkpoint['model'])
|
||||||
|
policy.optim.load_state_dict(checkpoint['optim'])
|
||||||
|
print("Successfully restore policy and optim.")
|
||||||
|
else:
|
||||||
|
print("Fail to restore policy and optim.")
|
||||||
|
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
|
||||||
|
if os.path.exists(buffer_path):
|
||||||
|
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
|
||||||
|
print("Successfully restore buffer.")
|
||||||
|
else:
|
||||||
|
print("Fail to restore buffer.")
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy, train_collector, test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
|
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
|
||||||
|
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||||
|
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
policy.eval()
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rainbow_resume(args=get_args()):
|
||||||
|
args.resume = True
|
||||||
|
test_rainbow(args)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prainbow(args=get_args()):
|
||||||
|
args.prioritized_replay = True
|
||||||
|
args.gamma = .95
|
||||||
|
args.seed = 1
|
||||||
|
test_rainbow(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_rainbow(get_args())
|
@ -10,13 +10,22 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
|
|
||||||
:param float alpha: the prioritization exponent.
|
:param float alpha: the prioritization exponent.
|
||||||
:param float beta: the importance sample soft coefficient.
|
:param float beta: the importance sample soft coefficient.
|
||||||
|
:param bool weight_norm: whether to normalize returned weights with the maximum
|
||||||
|
weight value within the batch. Default to True.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
alpha: float,
|
||||||
|
beta: float,
|
||||||
|
weight_norm: bool = True,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
# will raise KeyError in PrioritizedVectorReplayBuffer
|
# will raise KeyError in PrioritizedVectorReplayBuffer
|
||||||
# super().__init__(size, **kwargs)
|
# super().__init__(size, **kwargs)
|
||||||
ReplayBuffer.__init__(self, size, **kwargs)
|
ReplayBuffer.__init__(self, size, **kwargs)
|
||||||
@ -27,6 +36,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self.weight = SegmentTree(size)
|
self.weight = SegmentTree(size)
|
||||||
self.__eps = np.finfo(np.float32).eps.item()
|
self.__eps = np.finfo(np.float32).eps.item()
|
||||||
self.options.update(alpha=alpha, beta=beta)
|
self.options.update(alpha=alpha, beta=beta)
|
||||||
|
self._weight_norm = weight_norm
|
||||||
|
|
||||||
def init_weight(self, index: Union[int, np.ndarray]) -> None:
|
def init_weight(self, index: Union[int, np.ndarray]) -> None:
|
||||||
self.weight[index] = self._max_prio ** self._alpha
|
self.weight[index] = self._max_prio ** self._alpha
|
||||||
@ -83,5 +93,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
else:
|
else:
|
||||||
indices = index
|
indices = index
|
||||||
batch = super().__getitem__(indices)
|
batch = super().__getitem__(indices)
|
||||||
batch.weight = self.get_weight(indices)
|
weight = self.get_weight(indices)
|
||||||
|
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
|
||||||
|
batch.weight = weight / np.max(weight) if self._weight_norm else weight
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
def set_beta(self, beta: float) -> None:
|
||||||
|
self._beta = beta
|
||||||
|
@ -55,3 +55,7 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
|
|||||||
PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)
|
PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)
|
||||||
]
|
]
|
||||||
super().__init__(buffer_list)
|
super().__init__(buffer_list)
|
||||||
|
|
||||||
|
def set_beta(self, beta: float) -> None:
|
||||||
|
for buffer in self.buffers:
|
||||||
|
buffer.set_beta(beta)
|
||||||
|
@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy
|
|||||||
from tianshou.policy.random import RandomPolicy
|
from tianshou.policy.random import RandomPolicy
|
||||||
from tianshou.policy.modelfree.dqn import DQNPolicy
|
from tianshou.policy.modelfree.dqn import DQNPolicy
|
||||||
from tianshou.policy.modelfree.c51 import C51Policy
|
from tianshou.policy.modelfree.c51 import C51Policy
|
||||||
|
from tianshou.policy.modelfree.rainbow import RainbowPolicy
|
||||||
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
||||||
from tianshou.policy.modelfree.iqn import IQNPolicy
|
from tianshou.policy.modelfree.iqn import IQNPolicy
|
||||||
from tianshou.policy.modelfree.fqf import FQFPolicy
|
from tianshou.policy.modelfree.fqf import FQFPolicy
|
||||||
@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"RandomPolicy",
|
"RandomPolicy",
|
||||||
"DQNPolicy",
|
"DQNPolicy",
|
||||||
"C51Policy",
|
"C51Policy",
|
||||||
|
"RainbowPolicy",
|
||||||
"QRDQNPolicy",
|
"QRDQNPolicy",
|
||||||
"IQNPolicy",
|
"IQNPolicy",
|
||||||
"FQFPolicy",
|
"FQFPolicy",
|
||||||
|
37
tianshou/policy/modelfree/rainbow.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from tianshou.policy import C51Policy
|
||||||
|
from tianshou.data import Batch
|
||||||
|
from tianshou.utils.net.discrete import sample_noise
|
||||||
|
|
||||||
|
|
||||||
|
class RainbowPolicy(C51Policy):
|
||||||
|
"""Implementation of Rainbow DQN. arXiv:1710.02298.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: a model following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||||
|
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||||
|
:param float discount_factor: in [0, 1].
|
||||||
|
:param int num_atoms: the number of atoms in the support set of the
|
||||||
|
value distribution. Default to 51.
|
||||||
|
:param float v_min: the value of the smallest atom in the support set.
|
||||||
|
Default to -10.0.
|
||||||
|
:param float v_max: the value of the largest atom in the support set.
|
||||||
|
Default to 10.0.
|
||||||
|
:param int estimation_step: the number of steps to look ahead. Default to 1.
|
||||||
|
:param int target_update_freq: the target network update frequency (0 if
|
||||||
|
you do not use the target network). Default to 0.
|
||||||
|
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||||
|
Default to False.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.C51Policy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
|
sample_noise(self.model)
|
||||||
|
if self._target and sample_noise(self.model_old):
|
||||||
|
self.model_old.train() # so that NoisyLinear takes effect
|
||||||
|
return super().learn(batch, **kwargs)
|
@ -11,10 +11,11 @@ def miniblock(
|
|||||||
output_size: int = 0,
|
output_size: int = 0,
|
||||||
norm_layer: Optional[ModuleType] = None,
|
norm_layer: Optional[ModuleType] = None,
|
||||||
activation: Optional[ModuleType] = None,
|
activation: Optional[ModuleType] = None,
|
||||||
|
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||||
) -> List[nn.Module]:
|
) -> List[nn.Module]:
|
||||||
"""Construct a miniblock with given input/output-size, norm layer and \
|
"""Construct a miniblock with given input/output-size, norm layer and \
|
||||||
activation."""
|
activation."""
|
||||||
layers: List[nn.Module] = [nn.Linear(input_size, output_size)]
|
layers: List[nn.Module] = [linear_layer(input_size, output_size)]
|
||||||
if norm_layer is not None:
|
if norm_layer is not None:
|
||||||
layers += [norm_layer(output_size)] # type: ignore
|
layers += [norm_layer(output_size)] # type: ignore
|
||||||
if activation is not None:
|
if activation is not None:
|
||||||
@ -42,6 +43,8 @@ class MLP(nn.Module):
|
|||||||
the same actvition for all layers if passed in nn.Module, or different
|
the same actvition for all layers if passed in nn.Module, or different
|
||||||
activation for different Modules if passed in a list. Default to
|
activation for different Modules if passed in a list. Default to
|
||||||
nn.ReLU.
|
nn.ReLU.
|
||||||
|
:param device: which device to create this model on. Default to None.
|
||||||
|
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -52,6 +55,7 @@ class MLP(nn.Module):
|
|||||||
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
|
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
|
||||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
||||||
device: Optional[Union[str, int, torch.device]] = None,
|
device: Optional[Union[str, int, torch.device]] = None,
|
||||||
|
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -78,9 +82,9 @@ class MLP(nn.Module):
|
|||||||
for in_dim, out_dim, norm, activ in zip(
|
for in_dim, out_dim, norm, activ in zip(
|
||||||
hidden_sizes[:-1], hidden_sizes[1:],
|
hidden_sizes[:-1], hidden_sizes[1:],
|
||||||
norm_layer_list, activation_list):
|
norm_layer_list, activation_list):
|
||||||
model += miniblock(in_dim, out_dim, norm, activ)
|
model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
|
||||||
if output_dim > 0:
|
if output_dim > 0:
|
||||||
model += [nn.Linear(hidden_sizes[-1], output_dim)]
|
model += [linear_layer(hidden_sizes[-1], output_dim)]
|
||||||
self.output_dim = output_dim or hidden_sizes[-1]
|
self.output_dim = output_dim or hidden_sizes[-1]
|
||||||
self.model = nn.Sequential(*model)
|
self.model = nn.Sequential(*model)
|
||||||
|
|
||||||
@ -168,10 +172,10 @@ class Net(nn.Module):
|
|||||||
q_output_dim, v_output_dim = action_dim, num_atoms
|
q_output_dim, v_output_dim = action_dim, num_atoms
|
||||||
q_kwargs: Dict[str, Any] = {
|
q_kwargs: Dict[str, Any] = {
|
||||||
**q_kwargs, "input_dim": self.output_dim,
|
**q_kwargs, "input_dim": self.output_dim,
|
||||||
"output_dim": q_output_dim}
|
"output_dim": q_output_dim, "device": self.device}
|
||||||
v_kwargs: Dict[str, Any] = {
|
v_kwargs: Dict[str, Any] = {
|
||||||
**v_kwargs, "input_dim": self.output_dim,
|
**v_kwargs, "input_dim": self.output_dim,
|
||||||
"output_dim": v_output_dim}
|
"output_dim": v_output_dim, "device": self.device}
|
||||||
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
||||||
self.output_dim = self.Q.output_dim
|
self.output_dim = self.Q.output_dim
|
||||||
|
|
||||||
|
@ -307,3 +307,83 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
|
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
|
||||||
return (quantiles, fractions, quantiles_tau), h
|
return (quantiles, fractions, quantiles_tau), h
|
||||||
|
|
||||||
|
|
||||||
|
class NoisyLinear(nn.Module):
|
||||||
|
"""Implementation of Noisy Networks. arXiv:1706.10295.
|
||||||
|
|
||||||
|
:param int in_features: the number of input features.
|
||||||
|
:param int out_features: the number of output features.
|
||||||
|
:param float noisy_std: initial standard deviation of noisy linear layers.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
|
||||||
|
/fqf_iqn_qrdqn/network.py .
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features: int, out_features: int, noisy_std: float = 0.5
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Learnable parameters.
|
||||||
|
self.mu_W = nn.Parameter(
|
||||||
|
torch.FloatTensor(out_features, in_features))
|
||||||
|
self.sigma_W = nn.Parameter(
|
||||||
|
torch.FloatTensor(out_features, in_features))
|
||||||
|
self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
|
||||||
|
self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))
|
||||||
|
|
||||||
|
# Factorized noise parameters.
|
||||||
|
self.register_buffer('eps_p', torch.FloatTensor(in_features))
|
||||||
|
self.register_buffer('eps_q', torch.FloatTensor(out_features))
|
||||||
|
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.sigma = noisy_std
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
self.sample()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
bound = 1 / np.sqrt(self.in_features)
|
||||||
|
self.mu_W.data.uniform_(-bound, bound)
|
||||||
|
self.mu_bias.data.uniform_(-bound, bound)
|
||||||
|
self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features))
|
||||||
|
self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features))
|
||||||
|
|
||||||
|
def f(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = torch.randn(x.size(0), device=x.device)
|
||||||
|
return x.sign().mul_(x.abs().sqrt_())
|
||||||
|
|
||||||
|
def sample(self) -> None:
|
||||||
|
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
|
||||||
|
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.training:
|
||||||
|
weight = self.mu_W + self.sigma_W * (
|
||||||
|
self.eps_q.ger(self.eps_p) # type: ignore
|
||||||
|
)
|
||||||
|
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
|
||||||
|
else:
|
||||||
|
weight = self.mu_W
|
||||||
|
bias = self.mu_bias
|
||||||
|
|
||||||
|
return F.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_noise(model: nn.Module) -> bool:
|
||||||
|
"""Sample the random noises of NoisyLinear modules in the model.
|
||||||
|
|
||||||
|
:param model: a PyTorch module which may have NoisyLinear submodules.
|
||||||
|
:returns: True if model has at least one NoisyLinear submodule;
|
||||||
|
otherwise, False.
|
||||||
|
"""
|
||||||
|
done = False
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, NoisyLinear):
|
||||||
|
m.sample()
|
||||||
|
done = True
|
||||||
|
return done
|
||||||
|