Add Implicit Quantile Network (#371)

Yi Su 2021-05-28 18:44:23 -07:00 committed by GitHub
parent 458028a326
commit f3169b4c1f
19 changed files with 554 additions and 3 deletions

@@ -24,6 +24,7 @@
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)

@@ -35,6 +35,11 @@ DQN Family
   :undoc-members:
   :show-inheritance:

.. autoclass:: tianshou.policy.IQNPolicy
   :members:
   :undoc-members:
   :show-inheritance:

On-policy
~~~~~~~~~

@@ -14,6 +14,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_

@@ -54,6 +54,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` |
# IQN (single run)
One epoch here equals 100,000 env steps; 100 epochs correspond to 10M steps.
| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.9 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 578.3 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1507 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 15520 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 2911 | ![](results/iqn/MsPacman_rew.png) | `python3 atari_iqn.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 6236 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1370 | ![](results/iqn/SpaceInvader_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` |
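
IQN exposes three quantile-sampling hyperparameters on top of the QRDQN setup above. The defaults in `atari_iqn.py` (added by this commit, listed below) are `--sample-size 32` for action selection at evaluation time and `--online-sample-size 8` / `--target-sample-size 8` for the online and target networks during training, so the commands in the table are equivalent to passing these flags explicitly, e.g. `python3 atari_iqn.py --task "PongNoFrameskip-v4" --sample-size 32 --online-sample-size 8 --target-sample-size 8`.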
# BCQ
To run the BCQ algorithm on Atari, you need to do the following things:

@@ -107,7 +107,7 @@ def test_c51(args=get_args()):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        if env.env.spec.reward_threshold:
+        if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
         elif 'Pong' in args.task:
             return mean_rewards >= 20

@@ -102,7 +102,7 @@ def test_dqn(args=get_args()):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        if env.env.spec.reward_threshold:
+        if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
         elif 'Pong' in args.task:
             return mean_rewards >= 20

examples/atari/atari_iqn.py (new file, 183 lines)

@@ -0,0 +1,183 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import IQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import ImplicitQuantileNetwork

from atari_network import DQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--sample-size', type=int, default=32)
    parser.add_argument('--online-sample-size', type=int, default=8)
    parser.add_argument('--target-sample-size', type=int, default=8)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_iqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    feature_net = DQN(*args.state_shape, args.action_shape, args.device,
                      features_only=True)
    net = ImplicitQuantileNetwork(
        feature_net, args.action_shape, args.hidden_sizes,
        num_cosines=args.num_cosines, device=args.device
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = IQNPolicy(
        net, optim, args.gamma, args.sample_size, args.online_sample_size,
        args.target_sample_size, args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size, buffer_num=len(train_envs),
        ignore_obs_next=True, save_only_last_obs=True,
        stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'iqn')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        logger.write('train/eps', env_step, eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size, buffer_num=len(test_envs),
                ignore_obs_next=True, save_only_last_obs=True,
                stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        update_per_step=args.update_per_step, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_iqn(get_args())
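
A checkpoint written by `save_fn` above can be rebuilt outside the script with the same constructor calls. A minimal sketch, assuming the script's default hyperparameters, a CPU-only machine, and Pong's 6 discrete actions (hypothetical standalone usage, not part of the commit):

```python
# Minimal sketch: rebuild the IQN policy and reload a checkpoint written
# by save_fn (assumes the atari_iqn.py defaults above).
import torch
from atari_network import DQN
from tianshou.policy import IQNPolicy
from tianshou.utils.net.discrete import ImplicitQuantileNetwork

feature_net = DQN(4, 84, 84, 6, "cpu", features_only=True)  # 4x84x84 frames, 6 actions
net = ImplicitQuantileNetwork(feature_net, 6, [512], num_cosines=64, device="cpu")
optim = torch.optim.Adam(net.parameters(), lr=0.0001)
policy = IQNPolicy(net, optim, 0.99, 32, 8, 8, 3, target_update_freq=500)
policy.load_state_dict(torch.load(
    "log/PongNoFrameskip-v4/iqn/policy.pth", map_location="cpu"))
policy.eval()  # evaluation mode uses sample_size=32 quantile samples in forward()
```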

@@ -105,7 +105,7 @@ def test_qrdqn(args=get_args()):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        if env.env.spec.reward_threshold:
+        if env.spec.reward_threshold:
             return mean_rewards >= env.spec.reward_threshold
         elif 'Pong' in args.task:
             return mean_rewards >= 20

(Seven binary files added, not shown: the reward-curve plots referenced in the README table above, i.e. results/iqn/Pong_rew.png, Breakout_rew.png, Enduro_rew.png, Qbert_rew.png, MsPacman_rew.png, Seaquest_rew.png, and SpaceInvader_rew.png; 119-243 KiB each.)
test/discrete/test_iqn.py (new file, 149 lines)

@@ -0,0 +1,149 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import IQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--sample-size', type=int, default=32)
    parser.add_argument('--online-sample-size', type=int, default=8)
    parser.add_argument('--target-sample-size', type=int, default=8)
    parser.add_argument('--num-cosines', type=int, default=64)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[64, 64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay',
                        action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_iqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    feature_net = Net(args.state_shape, args.hidden_sizes[-1],
                      hidden_sizes=args.hidden_sizes[:-1], device=args.device,
                      softmax=False)
    net = ImplicitQuantileNetwork(
        feature_net, args.action_shape,
        num_cosines=args.num_cosines, device=args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = IQNPolicy(
        net, optim, args.gamma, args.sample_size, args.online_sample_size,
        args.target_sample_size, args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size, buffer_num=len(train_envs),
            alpha=args.alpha, beta=args.beta)
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'iqn')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        update_per_step=args.update_per_step)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_piqn(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    test_iqn(args)


if __name__ == '__main__':
    test_iqn(get_args())

tianshou/policy/__init__.py

@@ -3,6 +3,7 @@ from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.iqn import IQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.npg import NPGPolicy
@@ -26,6 +27,7 @@ __all__ = [
    "DQNPolicy",
    "C51Policy",
    "QRDQNPolicy",
    "IQNPolicy",
    "PGPolicy",
    "A2CPolicy",
    "NPGPolicy",

tianshou/policy/modelfree/iqn.py (new file, 105 lines)

@@ -0,0 +1,105 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict, Optional, Union

from tianshou.policy import QRDQNPolicy
from tianshou.data import Batch, to_numpy


class IQNPolicy(QRDQNPolicy):
    """Implementation of Implicit Quantile Network. arXiv:1806.06923.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int sample_size: the number of samples for policy evaluation.
        Default to 32.
    :param int online_sample_size: the number of samples for online model
        in training. Default to 8.
    :param int target_sample_size: the number of samples for target model
        in training. Default to 8.
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        sample_size: int = 32,
        online_sample_size: int = 8,
        target_sample_size: int = 8,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, optim, discount_factor, sample_size, estimation_step,
                         target_update_freq, reward_normalization, **kwargs)
        assert sample_size > 1, "sample_size should be greater than 1"
        assert online_sample_size > 1, "online_sample_size should be greater than 1"
        assert target_sample_size > 1, "target_sample_size should be greater than 1"
        self._sample_size = sample_size  # for policy eval
        self._online_sample_size = online_sample_size
        self._target_sample_size = target_sample_size

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        if model == "model_old":
            sample_size = self._target_sample_size
        elif self.training:
            sample_size = self._online_sample_size
        else:
            sample_size = self._sample_size
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        (logits, taus), h = model(
            obs_, sample_size=sample_size, state=state, info=batch.info
        )
        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        return Batch(logits=logits, act=act, state=h, taus=taus)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        out = self(batch)
        curr_dist, taus = out.logits, out.taus
        act = batch.act
        curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        huber_loss = (u * (
            taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
        ).abs()).sum(-1).mean(1)
        loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}
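
For reference, `learn` implements the quantile Huber loss from the IQN paper: with pairwise TD errors delta = target - current, each pair is weighted by |tau - 1{delta <= 0}|, summed over the N' target quantiles and averaged over the N online quantiles. A standalone shape check with toy tensors (hypothetical sizes, mirroring the names in `learn`; not part of the commit):

```python
# Toy shape check for the quantile Huber loss in IQNPolicy.learn:
# batch of 2 transitions, N=3 online samples, N'=4 target samples.
import torch
import torch.nn.functional as F

B, N, N_prime = 2, 3, 4
curr_dist = torch.randn(B, N, 1)           # chosen-action quantiles, (B, N, 1)
target_dist = torch.randn(B, 1, N_prime)   # target quantiles, (B, 1, N')
taus = torch.rand(B, N)                    # fractions sampled by the online net

# Pairwise Huber differences; broadcast explicitly to (B, N, N').
u = F.smooth_l1_loss(target_dist.expand(B, N, N_prime),
                     curr_dist.expand(B, N, N_prime), reduction="none")
# Weight by |tau - 1{delta <= 0}|, sum over N', average over N.
huber_loss = (u * (
    taus.unsqueeze(2) - (target_dist - curr_dist).le(0.).float()
).abs()).sum(-1).mean(1)
print(huber_loss.shape)  # torch.Size([2]): one loss per transition
```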

tianshou/utils/net/discrete.py

@@ -111,3 +111,94 @@ class Critic(nn.Module):
        """Mapping: s -> V(s)."""
        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
        return self.last(logits)


class CosineEmbeddingNetwork(nn.Module):
    """Cosine embedding network for IQN. Convert each scalar in [0, 1] to an \
    n-dim embedding vector.

    :param num_cosines: the number of cosines used for the embedding.
    :param embedding_dim: the dimension of the embedding/output.

    .. note::

        From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
        /fqf_iqn_qrdqn/network.py .
    """

    def __init__(self, num_cosines: int, embedding_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU())
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim

    def forward(self, taus: torch.Tensor) -> torch.Tensor:
        batch_size = taus.shape[0]
        N = taus.shape[1]
        # Calculate i * \pi (i=1,...,N).
        i_pi = np.pi * torch.arange(
            start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device
        ).view(1, 1, self.num_cosines)
        # Calculate cos(i * \pi * \tau).
        cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view(
            batch_size * N, self.num_cosines
        )
        # Calculate embeddings of taus.
        tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim)
        return tau_embeddings


class ImplicitQuantileNetwork(Critic):
    """Implicit Quantile Network.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: the dimension of the action space.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int num_cosines: the number of cosines to use for cosine embedding.
        Default to 64.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    .. note::

        Although this class inherits Critic, it is actually a quantile Q-Network
        with output shape (batch_size, action_dim, sample_size).

        The second item of the first return value is tau vector.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
        preprocess_net_output_dim: Optional[int] = None,
        device: Union[str, int, torch.device] = "cpu"
    ) -> None:
        last_size = np.prod(action_shape)
        super().__init__(preprocess_net, hidden_sizes, last_size,
                         preprocess_net_output_dim, device)
        self.input_dim = getattr(preprocess_net, "output_dim",
                                 preprocess_net_output_dim)
        self.embed_model = CosineEmbeddingNetwork(num_cosines,
                                                  self.input_dim).to(device)

    def forward(  # type: ignore
        self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
    ) -> Tuple[Any, torch.Tensor]:
        r"""Mapping: s -> Q(s, \*)."""
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        # Sample fractions.
        batch_size = logits.size(0)
        taus = torch.rand(batch_size, sample_size,
                          dtype=logits.dtype, device=logits.device)
        embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
            batch_size * sample_size, -1
        )
        out = self.last(embedding).view(batch_size,
                                        sample_size, -1).transpose(1, 2)
        return (out, taus), h
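
As the docstring notes, the forward pass returns quantile values of shape (batch_size, action_dim, sample_size) plus the sampled fractions; each tau is embedded by `CosineEmbeddingNetwork` as phi(tau) = ReLU(W [cos(pi*tau), cos(2*pi*tau), ..., cos(n*pi*tau)] + b). A minimal shape check, reusing `Net` as the preprocess net the way `test_iqn.py` above does (toy CartPole-like sizes; illustrative only, not part of the commit):

```python
# Toy shape check for ImplicitQuantileNetwork (hypothetical sizes).
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import ImplicitQuantileNetwork

feature_net = Net((4,), 64, hidden_sizes=[64], softmax=False)  # 4-dim obs -> 64-dim features
net = ImplicitQuantileNetwork(feature_net, (2,), num_cosines=64)  # 2 discrete actions
obs = torch.rand(5, 4)  # a batch of 5 observations
(quantiles, taus), hidden = net(obs, sample_size=8)
print(quantiles.shape)  # torch.Size([5, 2, 8]): (batch, action, sample)
print(taus.shape)       # torch.Size([5, 8]): one fraction per quantile sample
```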