Add Fully-parameterized Quantile Function (#376)

Yi Su 2021-06-14 20:59:02 -07:00 committed by GitHub
parent 21b2b22cd7
commit c0bc8e00ca
16 changed files with 630 additions and 2 deletions


@@ -25,6 +25,7 @@
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)


@@ -40,6 +40,11 @@ DQN Family
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.FQFPolicy
:members:
:undoc-members:
:show-inheritance:
On-policy
~~~~~~~~~


@@ -15,6 +15,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_


@@ -68,6 +68,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 4874 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1498.5 | ![](results/iqn/SpaceInvaders_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` |
# FQF (single run)
One epoch here equals 100,000 env steps; 100 epochs correspond to 10M env steps.
| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 2429 | ![](results/fqf/MsPacman_rew.png) | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
# BCQ
To run the BCQ algorithm on Atari, you need to do the following things:

examples/atari/atari_fqf.py (new file, 186 lines)

@@ -0,0 +1,186 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import FQFPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
from atari_network import DQN
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=3128)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-fractions', type=int, default=32)
parser.add_argument('--num-cosines', type=int, default=64)
parser.add_argument('--ent-coef', type=float, default=10.)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_fqf(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
feature_net = DQN(*args.state_shape, args.action_shape, args.device,
features_only=True)
net = FullQuantileFunction(
feature_net, args.action_shape, args.hidden_sizes,
args.num_cosines, device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(),
lr=args.fraction_lr)
# define policy
policy = FQFPolicy(
net, optim, fraction_net, fraction_optim,
args.gamma, args.num_fractions, args.ent_coef, args.n_step,
target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs),
ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'fqf')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
logger.write('train/eps', env_step, eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()
exit(0)
# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)
watch()
if __name__ == '__main__':
test_fqf(get_args())

(7 binary files not shown: the reward-curve plots under results/fqf/ referenced in the README table above, 140-215 KiB each.)

test/discrete/test_fqf.py (new file, 153 lines)

@@ -0,0 +1,153 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import FQFPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-3)
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--num-fractions', type=int, default=32)
parser.add_argument('--num-cosines', type=int, default=64)
parser.add_argument('--ent-coef', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[64, 64, 64])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay',
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_fqf(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
feature_net = Net(args.state_shape, args.hidden_sizes[-1],
hidden_sizes=args.hidden_sizes[:-1], device=args.device,
softmax=False)
net = FullQuantileFunction(
feature_net, args.action_shape, args.hidden_sizes,
num_cosines=args.num_cosines, device=args.device
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
fraction_optim = torch.optim.RMSprop(
fraction_net.parameters(), lr=args.fraction_lr
)
policy = FQFPolicy(
net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions,
args.ent_coef, args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs),
alpha=args.alpha, beta=args.beta)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'fqf')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step):
# eps annealing, just a demo
if env_step <= 10000:
policy.set_eps(args.eps_train)
elif env_step <= 50000:
eps = args.eps_train - (env_step - 10000) / \
40000 * (0.9 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.1 * args.eps_train)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_pfqf(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
test_fqf(args)
if __name__ == '__main__':
test_fqf(get_args())

tianshou/policy/__init__.py

@@ -4,6 +4,7 @@ from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.iqn import IQNPolicy
from tianshou.policy.modelfree.fqf import FQFPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.npg import NPGPolicy
@@ -28,6 +29,7 @@ __all__ = [
"C51Policy",
"QRDQNPolicy",
"IQNPolicy",
"FQFPolicy",
"PGPolicy",
"A2CPolicy",
"NPGPolicy",

tianshou/policy/modelfree/fqf.py (new file, 161 lines)

@@ -0,0 +1,161 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict, Optional, Union
from tianshou.policy import DQNPolicy, QRDQNPolicy
from tianshou.data import Batch, to_numpy, ReplayBuffer
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
class FQFPolicy(QRDQNPolicy):
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param FractionProposalNetwork fraction_model: a FractionProposalNetwork for
proposing fractions/quantiles given state.
:param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing
the fraction model above.
:param float discount_factor: in [0, 1].
:param int num_fractions: the number of fractions to use. Default to 32.
:param float ent_coef: the coefficient for entropy loss. Default to 0.
:param int estimation_step: the number of steps to look ahead. Default to 1.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
.. seealso::
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
explanation.
"""
def __init__(
self,
model: FullQuantileFunction,
optim: torch.optim.Optimizer,
fraction_model: FractionProposalNetwork,
fraction_optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
num_fractions: int = 32,
ent_coef: float = 0.0,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
model, optim, discount_factor, num_fractions, estimation_step,
target_update_freq, reward_normalization, **kwargs
)
self.propose_model = fraction_model
self._ent_coef = ent_coef
self._fraction_optim = fraction_optim
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
result = self(batch, input="obs_next")
a, fractions = result.act, result.fractions
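# the online model proposes the fractions and selects the greedy action,
# while the frozen target model evaluates quantiles at those same fractions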
next_dist = self(
batch, model="model_old", input="obs_next", fractions=fractions
).logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
return next_dist # shape: [bsz, num_quantiles]
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
fractions: Optional[Batch] = None,
**kwargs: Any,
) -> Batch:
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
if fractions is None:
(logits, fractions, quantiles_tau), h = model(
obs_, propose_model=self.propose_model, state=state, info=batch.info
)
else:
(logits, _, quantiles_tau), h = model(
obs_, propose_model=self.propose_model, fractions=fractions,
state=state, info=batch.info
)
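# Q(s, a) is the fraction-weighted average of the quantile values:
#   Q(s, a) ~= sum_i (tau_{i+1} - tau_i) * F_Z^{-1}(tau_hat_i)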
weighted_logits = (
fractions.taus[:, 1:] - fractions.taus[:, :-1]
).unsqueeze(1) * logits
q = DQNPolicy.compute_q_value(
self, weighted_logits.sum(2), getattr(obs, "mask", None)
)
if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
return Batch(
logits=logits, act=act, state=h, fractions=fractions,
quantiles_tau=quantiles_tau
)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
weight = batch.pop("weight", 1.0)
out = self(batch)
curr_dist_orig = out.logits
taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
act = batch.act
curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# pairwise Huber loss between every current quantile and every target quantile
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
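# quantile regression weighting: each Huber term is scaled by
# |tau_hat_i - 1{target <= current}|, so over- and under-estimation of a
# quantile are penalized asymmetrically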
huber_loss = (u * (
tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
).abs()).sum(-1).mean(1)
quantile_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
# calculate fraction loss
with torch.no_grad():
sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
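# FQF fraction gradient (cf. the paper):
#   dW_1/dtau_i = 2 * F_Z^{-1}(tau_i) - F_Z^{-1}(tau_hat_i) - F_Z^{-1}(tau_hat_{i-1})
# when the estimated quantile function is monotone, the sign-corrected terms
# below sum to exactly this quantity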
values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
signs_1 = sa_quantiles > torch.cat([
sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
signs_2 = sa_quantiles < torch.cat([
sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
gradient_of_taus = (
torch.where(signs_1, values_1, -values_1)
+ torch.where(signs_2, values_2, -values_2)
)
fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
# calculate entropy loss
entropy_loss = out.fractions.entropies.mean()
fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
self._fraction_optim.zero_grad()
fraction_entropy_loss.backward(retain_graph=True)
self._fraction_optim.step()
self.optim.zero_grad()
quantile_loss.backward()
self.optim.step()
self._iter += 1
return {
"loss": quantile_loss.item() + fraction_entropy_loss.item(),
"loss/quantile": quantile_loss.item(),
"loss/fraction": fraction_loss.item(),
"loss/entropy": entropy_loss.item()
}
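For readers skimming the diff, here is a minimal, illustrative sketch (not part of the commit) of how the pieces above fit together; it mirrors the construction in test/discrete/test_fqf.py, with CartPole-sized shapes and arbitrary hyper-parameters.
import numpy as np
import torch
from tianshou.data import Batch
from tianshou.policy import FQFPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction

state_shape, action_shape = (4,), 2   # CartPole-v0 observation/action sizes
feature_net = Net(state_shape, 64, hidden_sizes=[64], softmax=False)
net = FullQuantileFunction(feature_net, action_shape, [64], num_cosines=64)
optim = torch.optim.Adam(net.parameters(), lr=3e-3)
fraction_net = FractionProposalNetwork(32, net.input_dim)
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=2.5e-9)
policy = FQFPolicy(
    net, optim, fraction_net, fraction_optim,
    0.99, 32, 10., 3, target_update_freq=320)

# one forward pass: greedy actions plus per-action quantile values
batch = Batch(obs=np.random.rand(8, 4).astype(np.float32), info={})
out = policy(batch)
print(out.act.shape)              # (8,)
print(out.logits.shape)           # torch.Size([8, 2, 32]): [batch, action, fraction]
print(out.fractions.taus.shape)   # torch.Size([8, 33]): tau_0 = 0, ..., tau_32 = 1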

tianshou/utils/net/discrete.py

@@ -4,6 +4,7 @@ from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import Batch
from tianshou.utils.net.common import MLP
@@ -199,6 +200,110 @@ class ImplicitQuantileNetwork(Critic):
embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
batch_size * sample_size, -1
)
- out = self.last(embedding).view(batch_size,
- sample_size, -1).transpose(1, 2)
+ out = self.last(embedding).view(
+ batch_size, sample_size, -1).transpose(1, 2)
return (out, taus), h
class FractionProposalNetwork(nn.Module):
"""Fraction proposal network for FQF.
:param num_fractions: the number of fractions to propose.
:param embedding_dim: the dimension of the embedding/input.
.. note::
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
/fqf_iqn_qrdqn/network.py .
"""
def __init__(self, num_fractions: int, embedding_dim: int) -> None:
super().__init__()
self.net = nn.Linear(embedding_dim, num_fractions)
torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)
torch.nn.init.constant_(self.net.bias, 0)
self.num_fractions = num_fractions
self.embedding_dim = embedding_dim
def forward(
self, state_embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Calculate (log of) probabilities q_i in the paper.
m = torch.distributions.Categorical(logits=self.net(state_embeddings))
taus_1_N = torch.cumsum(m.probs, dim=1)
# Calculate \tau_i (i=0,...,N).
taus = F.pad(taus_1_N, (1, 0))
# Calculate \hat \tau_i (i=0,...,N-1).
tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
# Calculate entropies of value distributions.
entropies = m.entropy()
return taus, tau_hats, entropies
class FullQuantileFunction(ImplicitQuantileNetwork):
"""Full(y parameterized) Quantile Function.
:param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param int action_dim: the dimension of action space.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param int num_cosines: the number of cosines to use for cosine embedding.
Default to 64.
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
.. note::
The first return value is a tuple of (quantiles, fractions, quantiles_tau),
where fractions is a Batch(taus, tau_hats, entropies).
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_sizes: Sequence[int] = (),
num_cosines: int = 64,
preprocess_net_output_dim: Optional[int] = None,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__(
preprocess_net, action_shape, hidden_sizes,
num_cosines, preprocess_net_output_dim, device
)
def _compute_quantiles(
self, obs: torch.Tensor, taus: torch.Tensor
) -> torch.Tensor:
batch_size, sample_size = taus.shape
embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(
batch_size * sample_size, -1
)
quantiles = self.last(embedding).view(
batch_size, sample_size, -1
).transpose(1, 2)
return quantiles
def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor],
propose_model: FractionProposalNetwork,
fractions: Optional[Batch] = None,
**kwargs: Any
) -> Tuple[Any, torch.Tensor]:
r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None))
# Propose fractions
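# (the state embedding is detached so the fraction/entropy loss in
#  FQFPolicy.learn does not back-propagate into the preprocess network)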
if fractions is None:
taus, tau_hats, entropies = propose_model(logits.detach())
fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)
else:
taus, tau_hats = fractions.taus, fractions.tau_hats
quantiles = self._compute_quantiles(logits, tau_hats)
# Calculate quantiles_tau for computing fraction grad
quantiles_tau = None
if self.training:
with torch.no_grad():
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
return (quantiles, fractions, quantiles_tau), h
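Finally, an illustrative check (again not part of the commit) of what FractionProposalNetwork.forward returns: taus is the zero-padded cumulative sum of a softmax, so each row starts at 0, ends at 1, and is non-decreasing, while tau_hats are the midpoints at which FullQuantileFunction evaluates quantile values.
import torch
from tianshou.utils.net.discrete import FractionProposalNetwork

fpn = FractionProposalNetwork(num_fractions=8, embedding_dim=16)
state_emb = torch.randn(4, 16)                       # stand-in for the preprocess_net output
taus, tau_hats, entropies = fpn(state_emb)
print(taus.shape, tau_hats.shape, entropies.shape)   # (4, 9), (4, 8), (4,)
assert torch.allclose(taus[:, 0], torch.zeros(4))    # tau_0 = 0
assert torch.allclose(taus[:, -1], torch.ones(4))    # tau_N = 1
assert (taus[:, 1:] >= taus[:, :-1]).all()           # fractions are non-decreasing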