Add Rainbow DQN (#386)
- add RainbowPolicy - add `set_beta` method in prio_buffer - add NoisyLinear in utils/network
@ -22,6 +22,7 @@
|
||||
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
|
||||
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
|
||||
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
||||
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
|
||||
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
||||
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
||||
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
|
||||
|
@ -30,6 +30,11 @@ DQN Family
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.RainbowPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.QRDQNPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
@ -13,6 +13,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
|
||||
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
||||
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1707.02298.pdf>`_
|
||||
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
||||
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
||||
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
|
||||
|
@ -82,6 +82,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
| SeaquestNoFrameskip-v4 | 10775 |  | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 2482 |  | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
||||
# Rainbow (single run)
|
||||
|
||||
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| PongNoFrameskip-v4 | 21 |  | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` |
|
||||
| BreakoutNoFrameskip-v4 | 684.6 |  | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
|
||||
| EnduroNoFrameskip-v4 | 1625.9 |  | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` |
|
||||
| QbertNoFrameskip-v4 | 16192.5 |  | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` |
|
||||
| MsPacmanNoFrameskip-v4 | 3101 |  | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
|
||||
| SeaquestNoFrameskip-v4 | 2126 |  | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 1794.5 |  | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
||||
# BCQ
|
||||
|
||||
To running BCQ algorithm on Atari, you need to do the following things:
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||
from tianshou.utils.net.discrete import NoisyLinear
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
@ -81,6 +82,65 @@ class C51(DQN):
|
||||
return x, state
|
||||
|
||||
|
||||
class Rainbow(DQN):
|
||||
"""Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: int,
|
||||
h: int,
|
||||
w: int,
|
||||
action_shape: Sequence[int],
|
||||
num_atoms: int = 51,
|
||||
noisy_std: float = 0.5,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
is_dueling: bool = True,
|
||||
is_noisy: bool = True,
|
||||
) -> None:
|
||||
super().__init__(c, h, w, action_shape, device, features_only=True)
|
||||
self.action_num = np.prod(action_shape)
|
||||
self.num_atoms = num_atoms
|
||||
|
||||
def linear(x, y):
|
||||
if is_noisy:
|
||||
return NoisyLinear(x, y, noisy_std)
|
||||
else:
|
||||
return nn.Linear(x, y)
|
||||
|
||||
self.Q = nn.Sequential(
|
||||
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
||||
linear(512, self.action_num * self.num_atoms))
|
||||
self._is_dueling = is_dueling
|
||||
if self._is_dueling:
|
||||
self.V = nn.Sequential(
|
||||
linear(self.output_dim, 512), nn.ReLU(inplace=True),
|
||||
linear(512, self.num_atoms))
|
||||
self.output_dim = self.action_num * self.num_atoms
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
r"""Mapping: x -> Z(x, \*)."""
|
||||
x, state = super().forward(x)
|
||||
q = self.Q(x)
|
||||
q = q.view(-1, self.action_num, self.num_atoms)
|
||||
if self._is_dueling:
|
||||
v = self.V(x)
|
||||
v = v.view(-1, 1, self.num_atoms)
|
||||
logits = q - q.mean(dim=1, keepdim=True) + v
|
||||
else:
|
||||
logits = q
|
||||
y = logits.softmax(dim=2)
|
||||
return y, state
|
||||
|
||||
|
||||
class QRDQN(DQN):
|
||||
"""Reference: Distributional Reinforcement Learning with Quantile \
|
||||
Regression.
|
||||
|
204
examples/atari/atari_rainbow.py
Normal file
@ -0,0 +1,204 @@
|
||||
import os
|
||||
import torch
|
||||
import pprint
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import RainbowPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
|
||||
|
||||
from atari_network import Rainbow
|
||||
from atari_wrapper import wrap_deepmind
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0000625)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-atoms', type=int, default=51)
|
||||
parser.add_argument('--v-min', type=float, default=-10.)
|
||||
parser.add_argument('--v-max', type=float, default=10.)
|
||||
parser.add_argument('--noisy-std', type=float, default=0.1)
|
||||
parser.add_argument('--no-dueling', action='store_true', default=False)
|
||||
parser.add_argument('--no-noisy', action='store_true', default=False)
|
||||
parser.add_argument('--no-priority', action='store_true', default=False)
|
||||
parser.add_argument('--alpha', type=float, default=0.5)
|
||||
parser.add_argument('--beta', type=float, default=0.4)
|
||||
parser.add_argument('--beta-final', type=float, default=1.)
|
||||
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
|
||||
parser.add_argument('--no-weight-norm', action='store_true', default=False)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--watch', default=False, action='store_true',
|
||||
help='watch the play of pre-trained policy only')
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_atari_env(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||
|
||||
|
||||
def make_atari_env_watch(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||
episode_life=False, clip_rewards=False)
|
||||
|
||||
|
||||
def test_rainbow(args=get_args()):
|
||||
env = make_atari_env(args)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# should be N_FRAMES x H x W
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
|
||||
for _ in range(args.training_num)])
|
||||
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||
for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# define model
|
||||
net = Rainbow(*args.state_shape, args.action_shape,
|
||||
args.num_atoms, args.noisy_std, args.device,
|
||||
is_dueling=not args.no_dueling,
|
||||
is_noisy=not args.no_noisy)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
# define policy
|
||||
policy = RainbowPolicy(
|
||||
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
|
||||
args.n_step, target_update_freq=args.target_update_freq
|
||||
).to(args.device)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||
# when you have enough RAM
|
||||
if args.no_priority:
|
||||
buffer = VectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
|
||||
save_only_last_obs=True, stack_num=args.frames_stack)
|
||||
else:
|
||||
buffer = PrioritizedVectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
|
||||
save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
|
||||
beta=args.beta, weight_norm=not args.no_weight_norm)
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# log
|
||||
log_path = os.path.join(
|
||||
args.logdir, args.task, 'rainbow',
|
||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
if env_step <= 1e6:
|
||||
eps = args.eps_train - env_step / 1e6 * \
|
||||
(args.eps_train - args.eps_train_final)
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
if not args.no_priority:
|
||||
if env_step <= args.beta_anneal_step:
|
||||
beta = args.beta - env_step / args.beta_anneal_step * \
|
||||
(args.beta - args.beta_final)
|
||||
else:
|
||||
beta = args.beta_final
|
||||
buffer.set_beta(beta)
|
||||
logger.write('train/beta', env_step, beta)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
buffer = PrioritizedVectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(test_envs),
|
||||
ignore_obs_next=True, save_only_last_obs=True,
|
||||
stack_num=args.frames_stack, alpha=args.alpha,
|
||||
beta=args.beta)
|
||||
collector = Collector(policy, test_envs, buffer,
|
||||
exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num,
|
||||
render=args.render)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
# test train_collector and start filling replay buffer
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_rainbow(get_args())
|
BIN
examples/atari/results/rainbow/Breakout_rew.png
Normal file
After Width: | Height: | Size: 210 KiB |
BIN
examples/atari/results/rainbow/Enduro_rew.png
Normal file
After Width: | Height: | Size: 192 KiB |
BIN
examples/atari/results/rainbow/MsPacman_rew.png
Normal file
After Width: | Height: | Size: 233 KiB |
BIN
examples/atari/results/rainbow/Pong_rew.png
Normal file
After Width: | Height: | Size: 144 KiB |
BIN
examples/atari/results/rainbow/Qbert_rew.png
Normal file
After Width: | Height: | Size: 224 KiB |
BIN
examples/atari/results/rainbow/Seaquest_rew.png
Normal file
After Width: | Height: | Size: 198 KiB |
BIN
examples/atari/results/rainbow/SpaceInvaders_rew.png
Normal file
After Width: | Height: | Size: 226 KiB |
@ -193,7 +193,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
|
||||
mask = np.isin(np.arange(buf2.maxsize), indices)
|
||||
assert np.all(weight[mask] == weight[mask][0])
|
||||
assert np.all(weight[~mask] == weight[~mask][0])
|
||||
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1
|
||||
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
|
||||
|
||||
|
||||
def test_update():
|
||||
|
@ -54,6 +54,8 @@ def get_args():
|
||||
|
||||
def test_qrdqn(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
if args.task == 'CartPole-v0':
|
||||
env.spec.reward_threshold = 190 # lower the goal
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
|
@ -50,7 +50,7 @@ def test_discrete_cql(args=get_args()):
|
||||
# envs
|
||||
env = gym.make(args.task)
|
||||
if args.task == 'CartPole-v0':
|
||||
env.spec.reward_threshold = 190 # lower the goal
|
||||
env.spec.reward_threshold = 185 # lower the goal
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
test_envs = DummyVectorEnv(
|
||||
|
198
test/discrete/test_rainbow.py
Normal file
@ -0,0 +1,198 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pickle
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import RainbowPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.discrete import NoisyLinear
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--num-atoms', type=int, default=51)
|
||||
parser.add_argument('--v-min', type=float, default=-10.)
|
||||
parser.add_argument('--v-max', type=float, default=10.)
|
||||
parser.add_argument('--noisy-std', type=float, default=0.1)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=8000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=8)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.125)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
nargs='*', default=[128, 128, 128, 128])
|
||||
parser.add_argument('--training-num', type=int, default=8)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--prioritized-replay',
|
||||
action="store_true", default=False)
|
||||
parser.add_argument('--alpha', type=float, default=0.6)
|
||||
parser.add_argument('--beta', type=float, default=0.4)
|
||||
parser.add_argument('--beta-final', type=float, default=1.)
|
||||
parser.add_argument('--resume', action="store_true")
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument("--save-interval", type=int, default=4)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_rainbow(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
|
||||
def noisy_linear(x, y):
|
||||
return NoisyLinear(x, y, args.noisy_std)
|
||||
|
||||
net = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes, device=args.device,
|
||||
softmax=True, num_atoms=args.num_atoms,
|
||||
dueling_param=({"linear_layer": noisy_linear},
|
||||
{"linear_layer": noisy_linear}))
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = RainbowPolicy(
|
||||
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
|
||||
args.n_step, target_update_freq=args.target_update_freq
|
||||
).to(args.device)
|
||||
# buffer
|
||||
if args.prioritized_replay:
|
||||
buf = PrioritizedVectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(train_envs),
|
||||
alpha=args.alpha, beta=args.beta, weight_norm=True)
|
||||
else:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'rainbow')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
# eps annealing, just a demo
|
||||
if env_step <= 10000:
|
||||
policy.set_eps(args.eps_train)
|
||||
elif env_step <= 50000:
|
||||
eps = args.eps_train - (env_step - 10000) / \
|
||||
40000 * (0.9 * args.eps_train)
|
||||
policy.set_eps(eps)
|
||||
else:
|
||||
policy.set_eps(0.1 * args.eps_train)
|
||||
# beta annealing, just a demo
|
||||
if args.prioritized_replay:
|
||||
if env_step <= 10000:
|
||||
beta = args.beta
|
||||
elif env_step <= 50000:
|
||||
beta = args.beta - (env_step - 10000) / \
|
||||
40000 * (args.beta - args.beta_final)
|
||||
else:
|
||||
beta = args.beta_final
|
||||
buf.set_beta(beta)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
torch.save({
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth'))
|
||||
pickle.dump(train_collector.buffer,
|
||||
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
policy.optim.load_state_dict(checkpoint['optim'])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
|
||||
if os.path.exists(buffer_path):
|
||||
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
|
||||
print("Successfully restore buffer.")
|
||||
else:
|
||||
print("Fail to restore buffer.")
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
|
||||
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
def test_rainbow_resume(args=get_args()):
|
||||
args.resume = True
|
||||
test_rainbow(args)
|
||||
|
||||
|
||||
def test_prainbow(args=get_args()):
|
||||
args.prioritized_replay = True
|
||||
args.gamma = .95
|
||||
args.seed = 1
|
||||
test_rainbow(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_rainbow(get_args())
|
@ -10,13 +10,22 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
|
||||
:param float alpha: the prioritization exponent.
|
||||
:param float beta: the importance sample soft coefficient.
|
||||
:param bool weight_norm: whether to normalize returned weights with the maximum
|
||||
weight value within the batch. Default to True.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
weight_norm: bool = True,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
# will raise KeyError in PrioritizedVectorReplayBuffer
|
||||
# super().__init__(size, **kwargs)
|
||||
ReplayBuffer.__init__(self, size, **kwargs)
|
||||
@ -27,6 +36,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self.weight = SegmentTree(size)
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
self.options.update(alpha=alpha, beta=beta)
|
||||
self._weight_norm = weight_norm
|
||||
|
||||
def init_weight(self, index: Union[int, np.ndarray]) -> None:
|
||||
self.weight[index] = self._max_prio ** self._alpha
|
||||
@ -83,5 +93,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
else:
|
||||
indices = index
|
||||
batch = super().__getitem__(indices)
|
||||
batch.weight = self.get_weight(indices)
|
||||
weight = self.get_weight(indices)
|
||||
# ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154
|
||||
batch.weight = weight / np.max(weight) if self._weight_norm else weight
|
||||
return batch
|
||||
|
||||
def set_beta(self, beta: float) -> None:
|
||||
self._beta = beta
|
||||
|
@ -55,3 +55,7 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
|
||||
PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)
|
||||
]
|
||||
super().__init__(buffer_list)
|
||||
|
||||
def set_beta(self, beta: float) -> None:
|
||||
for buffer in self.buffers:
|
||||
buffer.set_beta(beta)
|
||||
|
@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy
|
||||
from tianshou.policy.random import RandomPolicy
|
||||
from tianshou.policy.modelfree.dqn import DQNPolicy
|
||||
from tianshou.policy.modelfree.c51 import C51Policy
|
||||
from tianshou.policy.modelfree.rainbow import RainbowPolicy
|
||||
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
||||
from tianshou.policy.modelfree.iqn import IQNPolicy
|
||||
from tianshou.policy.modelfree.fqf import FQFPolicy
|
||||
@ -27,6 +28,7 @@ __all__ = [
|
||||
"RandomPolicy",
|
||||
"DQNPolicy",
|
||||
"C51Policy",
|
||||
"RainbowPolicy",
|
||||
"QRDQNPolicy",
|
||||
"IQNPolicy",
|
||||
"FQFPolicy",
|
||||
|
37
tianshou/policy/modelfree/rainbow.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from tianshou.policy import C51Policy
|
||||
from tianshou.data import Batch
|
||||
from tianshou.utils.net.discrete import sample_noise
|
||||
|
||||
|
||||
class RainbowPolicy(C51Policy):
|
||||
"""Implementation of Rainbow DQN. arXiv:1710.02298.
|
||||
|
||||
:param torch.nn.Module model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param int num_atoms: the number of atoms in the support set of the
|
||||
value distribution. Default to 51.
|
||||
:param float v_min: the value of the smallest atom in the support set.
|
||||
Default to -10.0.
|
||||
:param float v_max: the value of the largest atom in the support set.
|
||||
Default to 10.0.
|
||||
:param int estimation_step: the number of steps to look ahead. Default to 1.
|
||||
:param int target_update_freq: the target network update frequency (0 if
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.C51Policy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
sample_noise(self.model)
|
||||
if self._target and sample_noise(self.model_old):
|
||||
self.model_old.train() # so that NoisyLinear takes effect
|
||||
return super().learn(batch, **kwargs)
|
@ -11,10 +11,11 @@ def miniblock(
|
||||
output_size: int = 0,
|
||||
norm_layer: Optional[ModuleType] = None,
|
||||
activation: Optional[ModuleType] = None,
|
||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||
) -> List[nn.Module]:
|
||||
"""Construct a miniblock with given input/output-size, norm layer and \
|
||||
activation."""
|
||||
layers: List[nn.Module] = [nn.Linear(input_size, output_size)]
|
||||
layers: List[nn.Module] = [linear_layer(input_size, output_size)]
|
||||
if norm_layer is not None:
|
||||
layers += [norm_layer(output_size)] # type: ignore
|
||||
if activation is not None:
|
||||
@ -42,6 +43,8 @@ class MLP(nn.Module):
|
||||
the same actvition for all layers if passed in nn.Module, or different
|
||||
activation for different Modules if passed in a list. Default to
|
||||
nn.ReLU.
|
||||
:param device: which device to create this model on. Default to None.
|
||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -52,6 +55,7 @@ class MLP(nn.Module):
|
||||
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
|
||||
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
|
||||
device: Optional[Union[str, int, torch.device]] = None,
|
||||
linear_layer: Type[nn.Linear] = nn.Linear,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -78,9 +82,9 @@ class MLP(nn.Module):
|
||||
for in_dim, out_dim, norm, activ in zip(
|
||||
hidden_sizes[:-1], hidden_sizes[1:],
|
||||
norm_layer_list, activation_list):
|
||||
model += miniblock(in_dim, out_dim, norm, activ)
|
||||
model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
|
||||
if output_dim > 0:
|
||||
model += [nn.Linear(hidden_sizes[-1], output_dim)]
|
||||
model += [linear_layer(hidden_sizes[-1], output_dim)]
|
||||
self.output_dim = output_dim or hidden_sizes[-1]
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
@ -168,10 +172,10 @@ class Net(nn.Module):
|
||||
q_output_dim, v_output_dim = action_dim, num_atoms
|
||||
q_kwargs: Dict[str, Any] = {
|
||||
**q_kwargs, "input_dim": self.output_dim,
|
||||
"output_dim": q_output_dim}
|
||||
"output_dim": q_output_dim, "device": self.device}
|
||||
v_kwargs: Dict[str, Any] = {
|
||||
**v_kwargs, "input_dim": self.output_dim,
|
||||
"output_dim": v_output_dim}
|
||||
"output_dim": v_output_dim, "device": self.device}
|
||||
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
|
||||
self.output_dim = self.Q.output_dim
|
||||
|
||||
|
@ -307,3 +307,83 @@ class FullQuantileFunction(ImplicitQuantileNetwork):
|
||||
with torch.no_grad():
|
||||
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
|
||||
return (quantiles, fractions, quantiles_tau), h
|
||||
|
||||
|
||||
class NoisyLinear(nn.Module):
|
||||
"""Implementation of Noisy Networks. arXiv:1706.10295.
|
||||
|
||||
:param int in_features: the number of input features.
|
||||
:param int out_features: the number of output features.
|
||||
:param float noisy_std: initial standard deviation of noisy linear layers.
|
||||
|
||||
.. note::
|
||||
|
||||
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
|
||||
/fqf_iqn_qrdqn/network.py .
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int, noisy_std: float = 0.5
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Learnable parameters.
|
||||
self.mu_W = nn.Parameter(
|
||||
torch.FloatTensor(out_features, in_features))
|
||||
self.sigma_W = nn.Parameter(
|
||||
torch.FloatTensor(out_features, in_features))
|
||||
self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
|
||||
self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))
|
||||
|
||||
# Factorized noise parameters.
|
||||
self.register_buffer('eps_p', torch.FloatTensor(in_features))
|
||||
self.register_buffer('eps_q', torch.FloatTensor(out_features))
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.sigma = noisy_std
|
||||
|
||||
self.reset()
|
||||
self.sample()
|
||||
|
||||
def reset(self) -> None:
|
||||
bound = 1 / np.sqrt(self.in_features)
|
||||
self.mu_W.data.uniform_(-bound, bound)
|
||||
self.mu_bias.data.uniform_(-bound, bound)
|
||||
self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features))
|
||||
self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features))
|
||||
|
||||
def f(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.randn(x.size(0), device=x.device)
|
||||
return x.sign().mul_(x.abs().sqrt_())
|
||||
|
||||
def sample(self) -> None:
|
||||
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
|
||||
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.training:
|
||||
weight = self.mu_W + self.sigma_W * (
|
||||
self.eps_q.ger(self.eps_p) # type: ignore
|
||||
)
|
||||
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
|
||||
else:
|
||||
weight = self.mu_W
|
||||
bias = self.mu_bias
|
||||
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
|
||||
def sample_noise(model: nn.Module) -> bool:
|
||||
"""Sample the random noises of NoisyLinear modules in the model.
|
||||
|
||||
:param model: a PyTorch module which may have NoisyLinear submodules.
|
||||
:returns: True if model has at least one NoisyLinear submodule;
|
||||
otherwise, False.
|
||||
"""
|
||||
done = False
|
||||
for m in model.modules():
|
||||
if isinstance(m, NoisyLinear):
|
||||
m.sample()
|
||||
done = True
|
||||
return done
|
||||
|