Add C51 algorithm (#266)
This PR adds the C51 algorithm (https://arxiv.org/abs/1707.06887):

1. Add the C51 policy in `tianshou/policy/modelfree/c51.py`.
2. Add the C51 network in `tianshou/utils/net/discrete.py`.
3. Add a C51 Atari example in `examples/atari/atari_c51.py`.
4. Export `C51Policy` in `tianshou/policy/__init__.py`.
5. Add a C51 test in `test/discrete/test_c51.py`.
6. Add C51 Atari results in `examples/atari/results/c51/`.

Running `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` reaches a best reward of 20.50 ± 0.50 at epoch 9. Running `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40` reaches a best reward of 407.40 ± 31.16 at epoch 39.
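For reference, a minimal sketch of how the new pieces fit together (hyperparameters follow the defaults in `examples/atari/atari_c51.py` below; the 6-action output shape is only illustrative):

```python
import torch

from tianshou.policy import C51Policy
from tianshou.utils.net.discrete import C51

# stacked frames (C x H x W) -> per-action distribution over 51 atoms
net = C51(4, 84, 84, action_shape=(6,), num_atoms=51)
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
policy = C51Policy(net, optim, discount_factor=0.99, num_atoms=51,
                   v_min=-10.0, v_max=10.0, estimation_step=3,
                   target_update_freq=500)
```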
@@ -23,6 +23,7 @@
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
 - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
+- [C51](https://arxiv.org/pdf/1707.06887.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -13,6 +13,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
+* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@@ -23,3 +23,19 @@ One epoch here is equal to 100,000 env steps, 100 epochs stand for 10M env steps.
 Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve the performance. Also, a large batch size (say 64 instead of 32) will help faster convergence but will slow down the training speed.
 
 We haven't tuned this result to the best, so have fun playing with these hyperparameters!
+
+# C51 (single run)
+
+One epoch here is equal to 100,000 env steps, 100 epochs stand for 10M env steps.
+
+| task                        | best reward | reward curve                          | parameters                                                   |
+| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
+| PongNoFrameskip-v4          | 20          |        | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` |
+| BreakoutNoFrameskip-v4      | 536.6       |    | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
+| EnduroNoFrameskip-v4        | 1032        |      | `python3 atari_c51.py --task "EnduroNoFrameskip-v4"` |
+| QbertNoFrameskip-v4         | 16245       |        | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` |
+| MsPacmanNoFrameskip-v4      | 3133        |    | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` |
+| SeaquestNoFrameskip-v4      | 6226        |    | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` |
+| SpaceInvadersNoFrameskip-v4 | 988.5       | | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` |
+
+Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.
examples/atari/atari_c51.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames_stack', type=int, default=4)
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_c51(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape,
              args.num_atoms, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from:", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_c51(get_args())
BIN  examples/atari/results/c51/Breakout_rew.png (new file, 70 KiB)
BIN  examples/atari/results/c51/Enduro_rew.png (new file, 60 KiB)
BIN  examples/atari/results/c51/MsPacman_rew.png (new file, 53 KiB)
BIN  examples/atari/results/c51/Pong_rew.png (new file, 38 KiB)
BIN  examples/atari/results/c51/Qbert_rew.png (new file, 61 KiB)
BIN  examples/atari/results/c51/Seaquest_rew.png (new file, 52 KiB)
BIN  examples/atari/results/c51/SpaceInvader_rew.png (new file, 66 KiB)
@@ -76,6 +76,10 @@ def target_q_fn(buffer, indice):
     return torch.tensor(-buffer.rew[indice], dtype=torch.float32)
 
 
+def target_q_fn_multidim(buffer, indice):
+    return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51)
+
+
 def compute_nstep_return_base(nstep, gamma, buffer, indice):
     returns = np.zeros_like(indice, dtype=np.float)
     buf_len = len(buffer)
@@ -108,6 +112,10 @@ def test_nstep_returns(size=10000):
     assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
     r_ = compute_nstep_return_base(1, .1, buf, indice)
     assert np.allclose(returns, r_), (r_, returns)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
     # test nstep = 2
     returns = to_numpy(BasePolicy.compute_nstep_return(
         batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
@@ -115,6 +123,10 @@ def test_nstep_returns(size=10000):
         3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
     r_ = compute_nstep_return_base(2, .1, buf, indice)
     assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
     # test nstep = 10
     returns = to_numpy(BasePolicy.compute_nstep_return(
         batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
@@ -122,6 +134,10 @@ def test_nstep_returns(size=10000):
         3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
     r_ = compute_nstep_return_base(10, .1, buf, indice)
     assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
 
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
@@ -4,7 +4,7 @@ import numpy as np
 from tianshou.utils import MovAvg
 from tianshou.utils import SummaryWriter
 from tianshou.utils.net.common import Net
-from tianshou.utils.net.discrete import DQN
+from tianshou.utils.net.discrete import DQN, C51
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
 
@@ -61,6 +61,10 @@ def test_net():
     expect_output_shape = [bsz, *action_shape]
     net = DQN(*state_shape, action_shape)
     assert list(net(data)[0].shape) == expect_output_shape
+    num_atoms = 51
+    net = C51(*state_shape, action_shape, num_atoms)
+    expect_output_shape = [bsz, *action_shape, num_atoms]
+    assert list(net(data)[0].shape) == expect_output_shape
 
 
 def test_summary_writer():
test/discrete/test_c51.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay', type=int, default=0)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_c51(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
              softmax=True, num_atoms=args.num_atoms)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay > 0:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')


def test_pc51(args=get_args()):
    args.prioritized_replay = 1
    args.gamma = .95
    args.seed = 1
    test_c51(args)


if __name__ == '__main__':
    test_c51(get_args())
@@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy
 from tianshou.policy.random import RandomPolicy
 from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.modelfree.dqn import DQNPolicy
+from tianshou.policy.modelfree.c51 import C51Policy
 from tianshou.policy.modelfree.pg import PGPolicy
 from tianshou.policy.modelfree.a2c import A2CPolicy
 from tianshou.policy.modelfree.ddpg import DDPGPolicy
@@ -18,6 +19,7 @@ __all__ = [
     "RandomPolicy",
     "ImitationPolicy",
     "DQNPolicy",
+    "C51Policy",
     "PGPolicy",
     "A2CPolicy",
     "DDPGPolicy",
@@ -245,7 +245,7 @@ class BasePolicy(ABC, nn.Module):
             to False.
 
         :return: a Batch. The result will be stored in batch.returns as a
-            torch.Tensor with shape (bsz, ).
+            torch.Tensor with the same shape as target_q_fn's return tensor.
         """
         rew = buffer.rew
         if rew_norm:
@@ -257,12 +257,11 @@ class BasePolicy(ABC, nn.Module):
             mean, std = 0.0, 1.0
         buf_len = len(buffer)
         terminal = (indice + n_step - 1) % buf_len
-        target_q_torch = target_q_fn(buffer, terminal).flatten()  # (bsz, )
+        target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
         target_q = to_numpy(target_q_torch)
-
         target_q = _nstep_return(rew, buffer.done, target_q, indice,
                                  gamma, n_step, len(buffer), mean, std)
 
         batch.returns = to_torch_as(target_q, target_q_torch)
         if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
@@ -275,7 +274,7 @@ class BasePolicy(ABC, nn.Module):
     i64 = np.array([0, 1], dtype=np.int64)
     _episodic_return(f64, f64, b, 0.1, 0.1)
     _episodic_return(f32, f64, b, 0.1, 0.1)
-    _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
+    _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0)
 
 
 @njit
@@ -311,13 +310,18 @@ def _nstep_return(
     std: float,
 ) -> np.ndarray:
     """Numba speedup: 0.3s -> 0.15s."""
-    returns = np.zeros(indice.shape)
+    target_shape = target_q.shape
+    bsz = target_shape[0]
+    # change target_q to 2d array
+    target_q = target_q.reshape(bsz, -1)
+    returns = np.zeros(target_q.shape)
     gammas = np.full(indice.shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = (indice + n) % buf_len
         gammas[done[now] > 0] = n
         returns[done[now] > 0] = 0.0
-        returns = (rew[now] - mean) / std + gamma * returns
+        returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns
     target_q[gammas != n_step] = 0.0
+    gammas = gammas.reshape(-1, 1)
     target_q = target_q * (gamma ** gammas) + returns
-    return target_q
+    return target_q.reshape(target_shape)
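The reshape above is what lets `compute_nstep_return` work with both scalar targets of shape `(bsz,)` and distributional targets of shape `(bsz, num_atoms)`, as returned by C51's `_target_q`. A minimal NumPy sketch of the broadcasting idea (illustrative shapes only, not the numba-compiled function itself):

```python
import numpy as np

bsz, num_atoms, gamma = 4, 5, 0.9
rew = np.arange(bsz, dtype=float)         # one scalar reward per transition
target_q = np.ones((bsz, num_atoms))      # per-atom target values

# reshape the reward to (bsz, 1) so it broadcasts over the atom dimension,
# exactly like rew[now].reshape(-1, 1) in _nstep_return above
returns = rew.reshape(-1, 1) + gamma * target_q
assert returns.shape == (bsz, num_atoms)
```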
tianshou/policy/modelfree/c51.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import torch
import numpy as np
from typing import Any, Dict, Union, Optional

from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy


class C51Policy(DQNPolicy):
    """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int num_atoms: the number of atoms in the support set of the
        value distribution, defaults to 51.
    :param float v_min: the value of the smallest atom in the support set,
        defaults to -10.0.
    :param float v_max: the value of the largest atom in the support set,
        defaults to 10.0.
    :param int estimation_step: greater than 1, the number of steps to look
        ahead.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        num_atoms: int = 51,
        v_min: float = -10.0,
        v_max: float = 10.0,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, optim, discount_factor, estimation_step,
                         target_update_freq, reward_normalization, **kwargs)
        assert num_atoms > 1, "num_atoms should be greater than 1"
        assert v_min < v_max, "v_max should be larger than v_min"
        self._num_atoms = num_atoms
        self._v_min = v_min
        self._v_max = v_max
        self.support = torch.nn.Parameter(
            torch.linspace(self._v_min, self._v_max, self._num_atoms),
            requires_grad=False,
        )
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

    def _target_q(
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        dist, h = model(obs_, state=state, info=batch.info)
        q = (dist * self.support).sum(2)
        act: np.ndarray = to_numpy(q.max(dim=1)[1])
        if hasattr(obs, "mask"):
            # some of actions are masked, they cannot be selected
            q_: np.ndarray = to_numpy(q)
            q_[~obs.mask] = -np.inf
            act = q_.argmax(axis=1)
        # add eps to act in training or testing phase
        if not self.updating and not np.isclose(self.eps, 0.0):
            for i in range(len(q)):
                if np.random.rand() < self.eps:
                    q_ = np.random.rand(*q[i].shape)
                    if hasattr(obs, "mask"):
                        q_[~obs.mask[i]] = -np.inf
                    act[i] = q_.argmax()
        return Batch(logits=dist, act=act, state=h)

    def _target_dist(self, batch: Batch) -> torch.Tensor:
        if self._target:
            a = self(batch, input="obs_next").act
            next_dist = self(
                batch, model="model_old", input="obs_next"
            ).logits
        else:
            next_b = self(batch, input="obs_next")
            a = next_b.act
            next_dist = next_b.logits
        next_dist = next_dist[np.arange(len(a)), a, :]
        target_support = batch.returns.clamp(self._v_min, self._v_max)
        # An amazing trick for calculating the projection gracefully.
        # ref: https://github.com/ShangtongZhang/DeepRL
        target_dist = (1 - (target_support.unsqueeze(1) -
                            self.support.view(1, -1, 1)).abs() / self.delta_z
                       ).clamp(0, 1) * next_dist.unsqueeze(1)
        return target_dist.sum(-1)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        with torch.no_grad():
            target_dist = self._target_dist(batch)
        weight = batch.pop("weight", 1.0)
        curr_dist = self(batch).logits
        act = batch.act
        curr_dist = curr_dist[np.arange(len(act)), act, :]
        cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1)
        loss = (cross_entropy * weight).mean()
        # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
        batch.weight = cross_entropy.detach()  # prio-buffer
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {"loss": loss.item()}
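For readers unfamiliar with the projection used in `_target_dist`, here is a small self-contained illustration (hypothetical numbers): each clamped target atom spreads its probability mass to the neighbouring atoms of the fixed support with weight `(1 - |Tz_j - z_i| / delta_z)` clamped to `[0, 1]`, so the total mass is preserved:

```python
import torch

v_min, v_max, num_atoms = -10.0, 10.0, 51
support = torch.linspace(v_min, v_max, num_atoms)       # fixed atoms z_i
delta_z = (v_max - v_min) / (num_atoms - 1)

# three clamped target atoms Tz_j and their probabilities from the target net
target_support = torch.tensor([[-0.3, 0.1, 3.7]])       # shape (bsz=1, 3)
next_dist = torch.tensor([[0.2, 0.5, 0.3]])             # shape (1, 3), sums to 1

weights = (1 - (target_support.unsqueeze(1) -
                support.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)  # (1, 51, 3)
target_dist = (weights * next_dist.unsqueeze(1)).sum(-1)              # (1, 51)
print(target_dist.sum())  # ~1.0: the projected distribution stays normalized
```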
@@ -32,6 +32,8 @@ class Net(nn.Module):
         (for Dueling DQN), defaults to False.
     :param norm_layer: use which normalization before ReLU, e.g.,
         ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
+    :param int num_atoms: in order to expand to the net of distributional RL,
+        defaults to 1.
     """
 
     def __init__(
@@ -45,11 +47,14 @@
         hidden_layer_size: int = 128,
         dueling: Optional[Tuple[int, int]] = None,
         norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
+        num_atoms: int = 1,
     ) -> None:
         super().__init__()
         self.device = device
         self.dueling = dueling
         self.softmax = softmax
+        self.num_atoms = num_atoms
+        self.action_num = np.prod(action_shape)
         input_size = np.prod(state_shape)
         if concat:
             input_size += np.prod(action_shape)
@@ -62,7 +67,8 @@
 
         if dueling is None:
             if action_shape and not concat:
-                model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
+                model += [nn.Linear(
+                    hidden_layer_size, num_atoms * self.action_num)]
         else:  # dueling DQN
             q_layer_num, v_layer_num = dueling
             Q, V = [], []
@@ -75,8 +81,9 @@
                     hidden_layer_size, hidden_layer_size, norm_layer)
 
             if action_shape and not concat:
-                Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
-                V += [nn.Linear(hidden_layer_size, 1)]
+                Q += [nn.Linear(
+                    hidden_layer_size, num_atoms * self.action_num)]
+                V += [nn.Linear(hidden_layer_size, num_atoms)]
 
             self.Q = nn.Sequential(*Q)
             self.V = nn.Sequential(*V)
@@ -94,7 +101,12 @@
         logits = self.model(s)
         if self.dueling is not None:  # Dueling DQN
             q, v = self.Q(logits), self.V(logits)
+            if self.num_atoms > 1:
+                v = v.view(-1, 1, self.num_atoms)
+                q = q.view(-1, self.action_num, self.num_atoms)
             logits = q - q.mean(dim=1, keepdim=True) + v
+        elif self.num_atoms > 1:
+            logits = logits.view(-1, self.action_num, self.num_atoms)
         if self.softmax:
             logits = torch.softmax(logits, dim=-1)
         return logits, state
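As a quick sanity check of the `num_atoms` extension (the constructor call mirrors how `test_c51.py` builds the network; the CartPole-like shapes are illustrative):

```python
import torch
from tianshou.utils.net.common import Net

# 4-dim observation, 2 discrete actions, 51 atoms -> (batch, 2, 51) output
net = Net(3, (4,), (2,), 'cpu', softmax=True, num_atoms=51)
obs = torch.rand(8, 4)
logits, _ = net(obs)
print(logits.shape)    # torch.Size([8, 2, 51])
print(logits.sum(-1))  # every (state, action) slice sums to 1
```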
@@ -130,3 +130,36 @@
         if not isinstance(x, torch.Tensor):
             x = to_torch(x, device=self.device, dtype=torch.float32)
         return self.net(x), state
+
+
+class C51(DQN):
+    """Reference: A distributional perspective on reinforcement learning.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(
+        self,
+        c: int,
+        h: int,
+        w: int,
+        action_shape: Sequence[int],
+        num_atoms: int = 51,
+        device: Union[str, int, torch.device] = "cpu",
+    ) -> None:
+        super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
+        self.action_shape = action_shape
+        self.num_atoms = num_atoms
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+        state: Optional[Any] = None,
+        info: Dict[str, Any] = {},
+    ) -> Tuple[torch.Tensor, Any]:
+        r"""Mapping: x -> Z(x, \*)."""
+        x, state = super().forward(x)
+        x = x.view(-1, self.num_atoms).softmax(dim=-1)
+        x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
+        return x, state
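A similar check for the convolutional `C51` head itself (the 4×84×84 input and 6 actions are illustrative dummy values in the spirit of `test_net()` above):

```python
import numpy as np
from tianshou.utils.net.discrete import C51

net = C51(4, 84, 84, action_shape=(6,), num_atoms=51)   # stacked frames C x H x W
obs = np.random.rand(3, 4, 84, 84)                      # dummy observation batch
dist, _ = net(obs)
print(dist.shape)    # torch.Size([3, 6, 51])
print(dist.sum(-1))  # each action's atom distribution sums to 1
```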