Add C51 algorithm (#266)

This is the PR for the C51 algorithm: https://arxiv.org/abs/1707.06887

1. Add the C51 policy in `tianshou/policy/modelfree/c51.py`.
2. Add the C51 network in `tianshou/utils/net/discrete.py`.
3. Add a C51 Atari example in `examples/atari/atari_c51.py`.
4. Export `C51Policy` from `tianshou/policy/__init__.py`.
5. Add a C51 test in `test/discrete/test_c51.py`.
6. Add C51 Atari results in `examples/atari/results/c51/`.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
Authored by wizardsheng on 2021-01-06 10:17:45 +08:00, committed via GitHub.
Parent 5d13d8a453 · commit c6f2648e87
19 changed files with 533 additions and 11 deletions


@ -23,6 +23,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)


@ -13,6 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_


@ -23,3 +23,19 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve performance. Also, a larger batch size (say 64 instead of 32) helps convergence but slows down the training speed.
We haven't tuned this result to the best, so have fun with playing these hyperparameters!
# C51 (single run)
One epoch here equals 100,000 env steps; 100 epochs stand for 10M.
| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4        | 1032        | ![](results/c51/Enduro_rew.png)       | `python3 atari_c51.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 3133 | ![](results/c51/MsPacman_rew.png) | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |
Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.

examples/atari/atari_c51.py (new file, 155 lines)

@ -0,0 +1,155 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_c51(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape,
args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# watch agent's performance
def watch():
print("Testing agent ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)
if args.watch:
watch()
exit(0)
# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * 4)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
pprint.pprint(result)
watch()
if __name__ == '__main__':
test_c51(get_args())

Seven new reward-curve images added under examples/atari/results/c51/ (binary files not shown).


@ -76,6 +76,10 @@ def target_q_fn(buffer, indice):
return torch.tensor(-buffer.rew[indice], dtype=torch.float32)
def target_q_fn_multidim(buffer, indice):
return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51)
def compute_nstep_return_base(nstep, gamma, buffer, indice):
returns = np.zeros_like(indice, dtype=np.float)
buf_len = len(buffer)
@ -108,6 +112,10 @@ def test_nstep_returns(size=10000):
assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indice)
assert np.allclose(returns, r_), (r_, returns)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
@ -115,6 +123,10 @@ def test_nstep_returns(size=10000):
3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
@ -122,6 +134,10 @@ def test_nstep_returns(size=10000):
3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
if __name__ == '__main__':
buf = ReplayBuffer(size)


@ -4,7 +4,7 @@ import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN
from tianshou.utils.net.discrete import DQN, C51
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@ -61,6 +61,10 @@ def test_net():
expect_output_shape = [bsz, *action_shape]
net = DQN(*state_shape, action_shape)
assert list(net(data)[0].shape) == expect_output_shape
num_atoms = 51
net = C51(*state_shape, action_shape, num_atoms)
expect_output_shape = [bsz, *action_shape, num_atoms]
assert list(net(data)[0].shape) == expect_output_shape
def test_summary_writer():

test/discrete/test_c51.py (new file, 135 lines)

@ -0,0 +1,135 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_c51(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
softmax=True, num_atoms=args.num_atoms)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay > 0:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.beta)
else:
buf = ReplayBuffer(args.buffer_size)
# collector
train_collector = Collector(policy, train_envs, buf)
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step):
# eps annealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
def test_pc51(args=get_args()):
args.prioritized_replay = 1
args.gamma = .95
args.seed = 1
test_c51(args)
if __name__ == '__main__':
test_c51(get_args())


@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy
from tianshou.policy.random import RandomPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
@ -18,6 +19,7 @@ __all__ = [
"RandomPolicy",
"ImitationPolicy",
"DQNPolicy",
"C51Policy",
"PGPolicy",
"A2CPolicy",
"DDPGPolicy",


@ -245,7 +245,7 @@ class BasePolicy(ABC, nn.Module):
to False.
:return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with shape (bsz, ).
torch.Tensor with the same shape as target_q_fn's return tensor.
"""
rew = buffer.rew
if rew_norm:
@ -257,12 +257,11 @@ class BasePolicy(ABC, nn.Module):
mean, std = 0.0, 1.0
buf_len = len(buffer)
terminal = (indice + n_step - 1) % buf_len
target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, )
target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?)
target_q = to_numpy(target_q_torch)
target_q = _nstep_return(rew, buffer.done, target_q, indice,
gamma, n_step, len(buffer), mean, std)
batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update
batch.weight = to_torch_as(batch.weight, target_q_torch)
@ -275,7 +274,7 @@ class BasePolicy(ABC, nn.Module):
i64 = np.array([0, 1], dtype=np.int64)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
_nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
_nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0)
@njit
@ -311,13 +310,18 @@ def _nstep_return(
std: float,
) -> np.ndarray:
"""Numba speedup: 0.3s -> 0.15s."""
returns = np.zeros(indice.shape)
target_shape = target_q.shape
bsz = target_shape[0]
# change target_q to 2d array
target_q = target_q.reshape(bsz, -1)
returns = np.zeros(target_q.shape)
gammas = np.full(indice.shape, n_step)
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0.0
returns = (rew[now] - mean) / std + gamma * returns
returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns
target_q[gammas != n_step] = 0.0
gammas = gammas.reshape(-1, 1)
target_q = target_q * (gamma ** gammas) + returns
return target_q
return target_q.reshape(target_shape)
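The generalization above is what lets `compute_nstep_return` work for C51: `_target_q` now returns a 2-D tensor (the support tiled per sample) instead of a scalar Q value per transition, and the n-step reward sum is broadcast over the extra atom dimension. A small NumPy sketch of that broadcast (illustrative values only; terminal handling and reward normalization are omitted):

```python
import numpy as np

gamma, n_step = 0.99, 1
support = np.linspace(-10.0, 10.0, 51)        # the 51 atoms z_i
rew = np.array([1.0, 0.0, -1.0])              # rewards at the terminal indices
target_q = np.tile(support, (len(rew), 1))    # what C51's target_q_fn returns: (bsz, num_atoms)

# returns[b, i] = r_b + gamma**n_step * z_i, i.e. the scalar reward is broadcast
# across atoms -- the same idea as the reshape(-1, 1) in _nstep_return above.
returns = rew.reshape(-1, 1) + gamma ** n_step * target_q
print(returns.shape)                          # (3, 51)
```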


@ -0,0 +1,143 @@
import torch
import numpy as np
from typing import Any, Dict, Union, Optional
from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy
class C51Policy(DQNPolicy):
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int num_atoms: the number of atoms in the support set of the
value distribution, defaults to 51.
:param float v_min: the value of the smallest atom in the support set,
defaults to -10.0.
:param float v_max: the value of the largest atom in the support set,
defaults to 10.0.
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
.. seealso::
Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
explanation.
"""
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
num_atoms: int = 51,
v_min: float = -10.0,
v_max: float = 10.0,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(model, optim, discount_factor, estimation_step,
target_update_freq, reward_normalization, **kwargs)
assert num_atoms > 1, "num_atoms should be greater than 1"
assert v_min < v_max, "v_max should be larger than v_min"
self._num_atoms = num_atoms
self._v_min = v_min
self._v_max = v_max
self.support = torch.nn.Parameter(
torch.linspace(self._v_min, self._v_max, self._num_atoms),
requires_grad=False,
)
self.delta_z = (v_max - v_min) / (num_atoms - 1)
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms]
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
* ``act`` the action.
* ``state`` the hidden state.
.. seealso::
Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
dist, h = model(obs_, state=state, info=batch.info)
q = (dist * self.support).sum(2)
act: np.ndarray = to_numpy(q.max(dim=1)[1])
if hasattr(obs, "mask"):
# some of actions are masked, they cannot be selected
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act in training or testing phase
if not self.updating and not np.isclose(self.eps, 0.0):
for i in range(len(q)):
if np.random.rand() < self.eps:
q_ = np.random.rand(*q[i].shape)
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
act[i] = q_.argmax()
return Batch(logits=dist, act=act, state=h)
def _target_dist(self, batch: Batch) -> torch.Tensor:
if self._target:
a = self(batch, input="obs_next").act
next_dist = self(
batch, model="model_old", input="obs_next"
).logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
target_support = batch.returns.clamp(self._v_min, self._v_max)
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
target_dist = (1 - (target_support.unsqueeze(1) -
self.support.view(1, -1, 1)).abs() / self.delta_z
).clamp(0, 1) * next_dist.unsqueeze(1)
return target_dist.sum(-1)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
with torch.no_grad():
target_dist = self._target_dist(batch)
weight = batch.pop("weight", 1.0)
curr_dist = self(batch).logits
act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :]
cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1)
loss = (cross_entropy * weight).mean()
# ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
batch.weight = cross_entropy.detach() # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
return {"loss": loss.item()}


@ -32,6 +32,8 @@ class Net(nn.Module):
(for Dueling DQN), defaults to False.
:param norm_layer: use which normalization before ReLU, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
:param int num_atoms: in order to expand to the net of distributional RL,
defaults to 1.
"""
def __init__(
@ -45,11 +47,14 @@ class Net(nn.Module):
hidden_layer_size: int = 128,
dueling: Optional[Tuple[int, int]] = None,
norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
num_atoms: int = 1,
) -> None:
super().__init__()
self.device = device
self.dueling = dueling
self.softmax = softmax
self.num_atoms = num_atoms
self.action_num = np.prod(action_shape)
input_size = np.prod(state_shape)
if concat:
input_size += np.prod(action_shape)
@ -62,7 +67,8 @@ class Net(nn.Module):
if dueling is None:
if action_shape and not concat:
model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
model += [nn.Linear(
hidden_layer_size, num_atoms * self.action_num)]
else: # dueling DQN
q_layer_num, v_layer_num = dueling
Q, V = [], []
@ -75,8 +81,9 @@ class Net(nn.Module):
hidden_layer_size, hidden_layer_size, norm_layer)
if action_shape and not concat:
Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
V += [nn.Linear(hidden_layer_size, 1)]
Q += [nn.Linear(
hidden_layer_size, num_atoms * self.action_num)]
V += [nn.Linear(hidden_layer_size, num_atoms)]
self.Q = nn.Sequential(*Q)
self.V = nn.Sequential(*V)
@ -94,7 +101,12 @@ class Net(nn.Module):
logits = self.model(s)
if self.dueling is not None: # Dueling DQN
q, v = self.Q(logits), self.V(logits)
if self.num_atoms > 1:
v = v.view(-1, 1, self.num_atoms)
q = q.view(-1, self.action_num, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
elif self.num_atoms > 1:
logits = logits.view(-1, self.action_num, self.num_atoms)
if self.softmax:
logits = torch.softmax(logits, dim=-1)
return logits, state
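For the MLP path, a quick shape check of the `num_atoms` extension (assumed toy shapes; any vector observation works the same way):

```python
import numpy as np
from tianshou.utils.net.common import Net

state_shape, action_shape, num_atoms = (4,), (2,), 51
net = Net(3, state_shape, action_shape, softmax=True, num_atoms=num_atoms)

obs = np.zeros((8, *state_shape), dtype=np.float32)
logits, _ = net(obs)
print(logits.shape)      # torch.Size([8, 2, 51]) -- (bsz, action_num, num_atoms)
print(logits.sum(-1))    # each action's atom distribution sums to 1
```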


@ -130,3 +130,36 @@ class DQN(nn.Module):
if not isinstance(x, torch.Tensor):
x = to_torch(x, device=self.device, dtype=torch.float32)
return self.net(x), state
class C51(DQN):
"""Reference: A distributional perspective on reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
num_atoms: int = 51,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
self.action_shape = action_shape
self.num_atoms = num_atoms
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
return x, state
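And the convolutional counterpart used by the Atari example, again just a shape check (assumed input of 4 stacked 84x84 frames and 6 discrete actions, as for Pong):

```python
import numpy as np
from tianshou.utils.net.discrete import C51

net = C51(4, 84, 84, action_shape=6, num_atoms=51, device="cpu")
obs = np.zeros((2, 4, 84, 84), dtype=np.float32)   # batch of 2 stacked-frame observations

dist, state = net(obs)
print(dist.shape)     # torch.Size([2, 6, 51]) -- per-action atom probabilities
print(dist.sum(-1))   # each row sums to 1 thanks to the softmax in forward()
```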