Add QR-DQN algorithm (#276)
This is the PR for the QR-DQN algorithm (https://arxiv.org/abs/1710.10044):

1. Add the QR-DQN policy in `tianshou/policy/modelfree/qrdqn.py`.
2. Add the QR-DQN network in `examples/atari/atari_network.py`.
3. Add a QR-DQN Atari example in `examples/atari/atari_qrdqn.py`.
4. Export the policy in `tianshou/policy/__init__.py`.
5. Add a QR-DQN unit test in `test/discrete/test_qrdqn.py`.
6. Add QR-DQN Atari results in `examples/atari/results/qrdqn/`.
7. Add `compute_q_value` to `DQNPolicy` and `C51Policy` to simplify the `forward` function.
8. Move `with torch.no_grad():` from `_target_q` into `BasePolicy`.

Running `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` reaches a best result of `19.8 ± 0.40` in epoch 8.
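For item 7, the distributional policies now only override how a scalar Q-value is derived from the network output instead of duplicating the whole `forward` logic. The snippet below is condensed from the diffs further down (the method bodies are verbatim; the stub base class and the `support` constructor argument are only there to keep the sketch self-contained, the real classes take more arguments):

```python
import torch


class DQNPolicy:  # stub standing in for tianshou.policy.DQNPolicy
    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
        # vanilla DQN: the network output is already Q(s, *)
        return logits


class C51Policy(DQNPolicy):
    def __init__(self, support: torch.Tensor):
        self.support = support  # atoms of the categorical value distribution

    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
        # C51: expectation of the categorical distribution over the atom support
        return (logits * self.support).sum(2)


class QRDQNPolicy(DQNPolicy):
    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
        # QR-DQN: mean over the quantile dimension
        return logits.mean(2)
```

`DQNPolicy.forward` calls `self.compute_q_value(logits)` once, so the action-masking and epsilon-greedy logic no longer has to be repeated in `C51Policy.forward`.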
@@ -24,6 +24,7 @@
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -14,6 +14,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@@ -40,6 +40,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.

# QRDQN (single run)

One epoch here is equal to 100,000 env steps; 100 epochs stand for 10M.

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20 |  | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 409.2 |  | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1055.9 |  | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 14990 |  | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 2886 |  | `python3 atari_qrdqn.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 5676 |  | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 938 |  | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` |

# BCQ

TODO: after the `done` issue is fixed, the results should be re-tuned and placed here.

@@ -49,4 +63,3 @@ To running BCQ algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
@@ -64,8 +64,8 @@ class C51(DQN):
        num_atoms: int = 51,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
-        super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
-        self.action_shape = action_shape
+        self.action_num = np.prod(action_shape)
+        super().__init__(c, h, w, [self.action_num * num_atoms], device)
        self.num_atoms = num_atoms

    def forward(
@@ -77,5 +77,38 @@ class C51(DQN):
        r"""Mapping: x -> Z(x, \*)."""
        x, state = super().forward(x)
        x = x.view(-1, self.num_atoms).softmax(dim=-1)
-        x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
+        x = x.view(-1, self.action_num, self.num_atoms)
        return x, state
+
+
+class QRDQN(DQN):
+    """Reference: Distributional Reinforcement Learning with Quantile \
+    Regression.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(
+        self,
+        c: int,
+        h: int,
+        w: int,
+        action_shape: Sequence[int],
+        num_quantiles: int = 200,
+        device: Union[str, int, torch.device] = "cpu",
+    ) -> None:
+        self.action_num = np.prod(action_shape)
+        super().__init__(c, h, w, [self.action_num * num_quantiles], device)
+        self.num_quantiles = num_quantiles
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+        state: Optional[Any] = None,
+        info: Dict[str, Any] = {},
+    ) -> Tuple[torch.Tensor, Any]:
+        r"""Mapping: x -> Z(x, \*)."""
+        x, state = super().forward(x)
+        x = x.view(-1, self.action_num, self.num_quantiles)
+        return x, state
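As a quick sanity check on the shapes produced by the `QRDQN` head above (a sketch run from `examples/atari/`; the frame size, the action count, and the assumption that the base `DQN.forward` accepts raw numpy observations follow the usual Atari setup and are not part of the diff):

```python
import numpy as np

from atari_network import QRDQN

# hypothetical Atari-style input: 4 stacked 84x84 frames, 6 discrete actions
net = QRDQN(4, 84, 84, action_shape=6, num_quantiles=200, device="cpu")
obs = np.zeros((32, 4, 84, 84), dtype=np.float32)  # dummy batch of observations
quantiles, state = net(obs)
print(quantiles.shape)  # expected: torch.Size([32, 6, 200]) -> [batch, action, quantile]
```

`QRDQNPolicy.compute_q_value` then averages over the last dimension to recover a `[batch, action]` Q-value table.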
examples/atari/atari_qrdqn.py (new file)
@@ -0,0 +1,153 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_qrdqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = QRDQN(*args.state_shape, args.action_shape,
                args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = QRDQNPolicy(
        net, optim, args.gamma, args.num_quantiles,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_qrdqn(get_args())
New binary files:

- examples/atari/results/qrdqn/Breakout_rew.png (62 KiB)
- examples/atari/results/qrdqn/Enduro_rew.png (67 KiB)
- examples/atari/results/qrdqn/MsPacman_rew.png (53 KiB)
- examples/atari/results/qrdqn/Pong_rew.png (37 KiB)
- examples/atari/results/qrdqn/Qbert_rew.png (67 KiB)
- examples/atari/results/qrdqn/Seaquest_rew.png (58 KiB)
- examples/atari/results/qrdqn/SpaceInvader_rew.png (55 KiB)
test/discrete/test_qrdqn.py (new file)
@@ -0,0 +1,136 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import QRDQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128, 128])
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay',
                        action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_qrdqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              softmax=False, num_atoms=args.num_quantiles)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
        net, optim, args.gamma, args.num_quantiles,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')


def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    args.seed = 1
    test_qrdqn(args)


if __name__ == '__main__':
    test_pqrdqn(get_args())
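The test reuses the generic `tianshou.utils.net.common.Net`: with `softmax=False` and `num_atoms=args.num_quantiles` it returns raw logits of shape `[batch, action_num, num_quantiles]`, which is exactly what `QRDQNPolicy.compute_q_value` and `learn` expect. A minimal shape check (the CartPole dimensions here are assumptions for illustration, not part of the diff):

```python
import numpy as np
from tianshou.utils.net.common import Net

# CartPole-v0: 4-dimensional observation, 2 discrete actions (assumed)
net = Net((4,), 2, hidden_sizes=[128, 128], softmax=False, num_atoms=200)
logits, state = net(np.zeros((8, 4), dtype=np.float32))
print(logits.shape)  # expected: torch.Size([8, 2, 200])
```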
@@ -3,6 +3,7 @@ from tianshou.policy.random import RandomPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
+from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
@@ -21,6 +22,7 @@ __all__ = [
    "ImitationPolicy",
    "DQNPolicy",
    "C51Policy",
+    "QRDQNPolicy",
    "PGPolicy",
    "A2CPolicy",
    "DDPGPolicy",
@@ -257,6 +257,7 @@ class BasePolicy(ABC, nn.Module):
        mean, std = 0.0, 1.0
        buf_len = len(buffer)
        terminal = (indice + n_step - 1) % buf_len
+        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch)
@@ -74,7 +74,6 @@ class DiscreteBCQPolicy(DQNPolicy):
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-        with torch.no_grad():
        act = self(batch, input="obs_next", eps=0.0).act
        target_q, _ = self.model_old(batch.obs_next)
        target_q = target_q[np.arange(len(act)), act]
@@ -1,9 +1,9 @@
import torch
import numpy as np
-from typing import Any, Dict, Union, Optional
+from typing import Any, Dict

from tianshou.policy import DQNPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy
+from tianshou.data import Batch, ReplayBuffer


class C51Policy(DQNPolicy):
@@ -63,46 +63,9 @@ class C51Policy(DQNPolicy):
    ) -> torch.Tensor:
        return self.support.repeat(len(indice), 1)  # shape: [bsz, num_atoms]

-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        model: str = "model",
-        input: str = "obs",
-        **kwargs: Any,
-    ) -> Batch:
-        """Compute action over the given batch data.
-
-        :return: A :class:`~tianshou.data.Batch` which has 2 keys:
-
-            * ``act`` the action.
-            * ``state`` the hidden state.
-
-        .. seealso::
-
-            Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for
-            more detailed explanation.
-        """
-        model = getattr(self, model)
-        obs = batch[input]
-        obs_ = obs.obs if hasattr(obs, "obs") else obs
-        dist, h = model(obs_, state=state, info=batch.info)
-        q = (dist * self.support).sum(2)
-        act: np.ndarray = to_numpy(q.max(dim=1)[1])
-        if hasattr(obs, "mask"):
-            # some of actions are masked, they cannot be selected
-            q_: np.ndarray = to_numpy(q)
-            q_[~obs.mask] = -np.inf
-            act = q_.argmax(axis=1)
-        # add eps to act in training or testing phase
-        if not self.updating and not np.isclose(self.eps, 0.0):
-            for i in range(len(q)):
-                if np.random.rand() < self.eps:
-                    q_ = np.random.rand(*q[i].shape)
-                    if hasattr(obs, "mask"):
-                        q_[~obs.mask[i]] = -np.inf
-                    act[i] = q_.argmax()
-        return Batch(logits=dist, act=act, state=h)
+    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
+        """Compute the q value based on the network's raw output logits."""
+        return (logits * self.support).sum(2)

    def _target_dist(self, batch: Batch) -> torch.Tensor:
        if self._target:
@@ -102,7 +102,6 @@ class DDPGPolicy(BasePolicy):
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
-        with torch.no_grad():
        target_q = self.critic_old(
            batch.obs_next,
            self(batch, model='actor_old', input='obs_next').act)
@@ -78,7 +78,6 @@ class DiscreteSACPolicy(SACPolicy):
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
-        with torch.no_grad():
        obs_next_result = self(batch, input="obs_next")
        dist = obs_next_result.dist
        target_q = dist.probs * torch.min(
@@ -79,7 +79,6 @@ class DQNPolicy(BasePolicy):
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-        with torch.no_grad():
        if self._target:
            a = self(batch, input="obs_next").act
            target_q = self(
@@ -103,6 +102,10 @@ class DQNPolicy(BasePolicy):
            self._gamma, self._n_step, self._rew_norm)
        return batch

+    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
+        """Compute the q value based on the network's raw output logits."""
+        return logits
+
    def forward(
        self,
        batch: Batch,
@@ -143,7 +146,8 @@ class DQNPolicy(BasePolicy):
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
-        q, h = model(obs_, state=state, info=batch.info)
+        logits, h = model(obs_, state=state, info=batch.info)
+        q = self.compute_q_value(logits)
        act: np.ndarray = to_numpy(q.max(dim=1)[1])
        if hasattr(obs, "mask"):
            # some of actions are masked, they cannot be selected
@@ -158,7 +162,7 @@ class DQNPolicy(BasePolicy):
                    if hasattr(obs, "mask"):
                        q_[~obs.mask[i]] = -np.inf
                    act[i] = q_.argmax()
-        return Batch(logits=q, act=act, state=h)
+        return Batch(logits=logits, act=act, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
tianshou/policy/modelfree/qrdqn.py (new file)
@@ -0,0 +1,94 @@
import torch
import warnings
import numpy as np
from typing import Any, Dict
import torch.nn.functional as F

from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer


class QRDQNPolicy(DQNPolicy):
    """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int num_quantiles: the number of quantile midpoints in the inverse
        cumulative distribution function of the value, defaults to 200.
    :param int estimation_step: greater than 1, the number of steps to look
        ahead.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        num_quantiles: int = 200,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, optim, discount_factor, estimation_step,
                         target_update_freq, reward_normalization, **kwargs)
        assert num_quantiles > 1, "num_quantiles should be greater than 1"
        self._num_quantiles = num_quantiles
        tau = torch.linspace(0, 1, self._num_quantiles + 1)
        self.tau_hat = torch.nn.Parameter(
            ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False)
        warnings.filterwarnings("ignore", message="Using a target size")

    def _target_q(
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        if self._target:
            a = self(batch, input="obs_next").act
            next_dist = self(
                batch, model="model_old", input="obs_next"
            ).logits
        else:
            next_b = self(batch, input="obs_next")
            a = next_b.act
            next_dist = next_b.logits
        next_dist = next_dist[np.arange(len(a)), a, :]
        return next_dist  # shape: [bsz, num_quantiles]

    def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute the q value based on the network's raw output logits."""
        return logits.mean(2)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        curr_dist = self(batch).logits
        act = batch.act
        curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        huber_loss = (u * (
            self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
        ).abs()).sum(-1).mean(1)
        loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}
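The loss computed in `learn` above is the quantile regression Huber loss from the QR-DQN paper. With $N$ = `num_quantiles`, quantile midpoints $\hat{\tau}_i = \frac{2i-1}{2N}$ (this is what the `tau_hat` buffer stores), predicted quantiles $\theta_i$ for the taken action, and $n$-step target quantiles $\mathcal{T}\theta_j$ taken from `batch.returns`, the per-sample loss is (up to the exact normalization convention over $j$):

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N}
\left| \hat{\tau}_i - \mathbb{1}\{\mathcal{T}\theta_j - \theta_i \le 0\} \right|
\, L_\kappa\!\left(\mathcal{T}\theta_j - \theta_i\right), \qquad \kappa = 1,
$$

where $L_\kappa$ is the Huber (smooth L1) loss computed by `F.smooth_l1_loss`. The elementwise error aggregated the same way, `u.detach().abs().sum(-1).mean(1)`, is written back to `batch.weight` so a prioritized replay buffer can use it as the new priority.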
@@ -140,7 +140,6 @@ class SACPolicy(DDPGPolicy):
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
-        with torch.no_grad():
        obs_next_result = self(batch, input='obs_next')
        a_ = obs_next_result.act
        target_q = torch.min(
@@ -104,7 +104,6 @@ class TD3Policy(DDPGPolicy):
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs: s_{t+n}
-        with torch.no_grad():
        a_ = self(batch, model="actor_old", input="obs_next").act
        dev = a_.device
        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise