Add Fully-parameterized Quantile Function (#376)
@ -25,6 +25,7 @@
|
||||
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
||||
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
||||
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
||||
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
|
||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
|
||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
||||
|
@ -40,6 +40,11 @@ DQN Family
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.FQFPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
On-policy
|
||||
~~~~~~~~~
|
||||
|
||||
|
@ -15,6 +15,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
||||
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
||||
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
||||
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
|
||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
|
||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||
|
@ -68,6 +68,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
| SeaquestNoFrameskip-v4 | 4874 |  | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 1498.5 |  | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
||||
# FQF (single run)
|
||||
|
||||
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
|
||||
| task | best reward | reward curve | parameters |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||
| PongNoFrameskip-v4 | 20.7 |  | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` |
|
||||
| BreakoutNoFrameskip-v4 | 517.3 |  | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
|
||||
| EnduroNoFrameskip-v4 | 2240.5 |  | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` |
|
||||
| QbertNoFrameskip-v4 | 16172.5 |  | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` |
|
||||
| MsPacmanNoFrameskip-v4 | 2429 |  | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` |
|
||||
| SeaquestNoFrameskip-v4 | 10775 |  | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
|
||||
| SpaceInvadersNoFrameskip-v4 | 2482 |  | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||
|
||||
# BCQ
|
||||
|
||||
To running BCQ algorithm on Atari, you need to do the following things:
|
||||
|
186
examples/atari/atari_fqf.py
Normal file
@ -0,0 +1,186 @@
|
||||
import os
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import FQFPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||
|
||||
from atari_network import DQN
|
||||
from atari_wrapper import wrap_deepmind
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=3128)
|
||||
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||
parser.add_argument('--eps-train', type=float, default=1.)
|
||||
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=5e-5)
|
||||
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--num-fractions', type=int, default=32)
|
||||
parser.add_argument('--num-cosines', type=int, default=64)
|
||||
parser.add_argument('--ent-coef', type=float, default=10.)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=32)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--watch', default=False, action='store_true',
|
||||
help='watch the play of pre-trained policy only')
|
||||
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_atari_env(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||
|
||||
|
||||
def make_atari_env_watch(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||
episode_life=False, clip_rewards=False)
|
||||
|
||||
|
||||
def test_fqf(args=get_args()):
|
||||
env = make_atari_env(args)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# should be N_FRAMES x H x W
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
|
||||
for _ in range(args.training_num)])
|
||||
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||
for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# define model
|
||||
feature_net = DQN(*args.state_shape, args.action_shape, args.device,
|
||||
features_only=True)
|
||||
net = FullQuantileFunction(
|
||||
feature_net, args.action_shape, args.hidden_sizes,
|
||||
args.num_cosines, device=args.device
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
|
||||
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(),
|
||||
lr=args.fraction_lr)
|
||||
# define policy
|
||||
policy = FQFPolicy(
|
||||
net, optim, fraction_net, fraction_optim,
|
||||
args.gamma, args.num_fractions, args.ent_coef, args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
).to(args.device)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||
# when you have enough RAM
|
||||
buffer = VectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(train_envs),
|
||||
ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
if env.spec.reward_threshold:
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
return mean_rewards >= 20
|
||||
else:
|
||||
return False
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
if env_step <= 1e6:
|
||||
eps = args.eps_train - env_step / 1e6 * \
|
||||
(args.eps_train - args.eps_train_final)
|
||||
else:
|
||||
eps = args.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
logger.write('train/eps', env_step, eps)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
if args.save_buffer_name:
|
||||
print(f"Generate buffer with size {args.buffer_size}")
|
||||
buffer = VectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(test_envs),
|
||||
ignore_obs_next=True, save_only_last_obs=True,
|
||||
stack_num=args.frames_stack)
|
||||
collector = Collector(policy, test_envs, buffer,
|
||||
exploration_noise=True)
|
||||
result = collector.collect(n_step=args.buffer_size)
|
||||
print(f"Save buffer into {args.save_buffer_name}")
|
||||
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||
buffer.save_hdf5(args.save_buffer_name)
|
||||
else:
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num,
|
||||
render=args.render)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
# test train_collector and start filling replay buffer
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fqf(get_args())
|
BIN
examples/atari/results/fqf/Breakout_rew.png
Normal file
After Width: | Height: | Size: 215 KiB |
BIN
examples/atari/results/fqf/Enduro_rew.png
Normal file
After Width: | Height: | Size: 188 KiB |
BIN
examples/atari/results/fqf/MsPacman_rew.png
Normal file
After Width: | Height: | Size: 200 KiB |
BIN
examples/atari/results/fqf/Pong_rew.png
Normal file
After Width: | Height: | Size: 140 KiB |
BIN
examples/atari/results/fqf/Qbert_rew.png
Normal file
After Width: | Height: | Size: 194 KiB |
BIN
examples/atari/results/fqf/Seaquest_rew.png
Normal file
After Width: | Height: | Size: 201 KiB |
BIN
examples/atari/results/fqf/SpaceInvaders_rew.png
Normal file
After Width: | Height: | Size: 215 KiB |
153
test/discrete/test_fqf.py
Normal file
@ -0,0 +1,153 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import FQFPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-3)
|
||||
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--num-fractions', type=int, default=32)
|
||||
parser.add_argument('--num-cosines', type=int, default=64)
|
||||
parser.add_argument('--ent-coef', type=float, default=10.)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=10000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
nargs='*', default=[64, 64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--prioritized-replay',
|
||||
action="store_true", default=False)
|
||||
parser.add_argument('--alpha', type=float, default=0.6)
|
||||
parser.add_argument('--beta', type=float, default=0.4)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_fqf(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
feature_net = Net(args.state_shape, args.hidden_sizes[-1],
|
||||
hidden_sizes=args.hidden_sizes[:-1], device=args.device,
|
||||
softmax=False)
|
||||
net = FullQuantileFunction(
|
||||
feature_net, args.action_shape, args.hidden_sizes,
|
||||
num_cosines=args.num_cosines, device=args.device
|
||||
)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
|
||||
fraction_optim = torch.optim.RMSprop(
|
||||
fraction_net.parameters(), lr=args.fraction_lr
|
||||
)
|
||||
policy = FQFPolicy(
|
||||
net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions,
|
||||
args.ent_coef, args.n_step, target_update_freq=args.target_update_freq
|
||||
).to(args.device)
|
||||
# buffer
|
||||
if args.prioritized_replay:
|
||||
buf = PrioritizedVectorReplayBuffer(
|
||||
args.buffer_size, buffer_num=len(train_envs),
|
||||
alpha=args.alpha, beta=args.beta)
|
||||
else:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
# eps annnealing, just a demo
|
||||
if env_step <= 10000:
|
||||
policy.set_eps(args.eps_train)
|
||||
elif env_step <= 50000:
|
||||
eps = args.eps_train - (env_step - 10000) / \
|
||||
40000 * (0.9 * args.eps_train)
|
||||
policy.set_eps(eps)
|
||||
else:
|
||||
policy.set_eps(0.1 * args.eps_train)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
def test_pfqf(args=get_args()):
|
||||
args.prioritized_replay = True
|
||||
args.gamma = .95
|
||||
test_fqf(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fqf(get_args())
|
@ -4,6 +4,7 @@ from tianshou.policy.modelfree.dqn import DQNPolicy
|
||||
from tianshou.policy.modelfree.c51 import C51Policy
|
||||
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
||||
from tianshou.policy.modelfree.iqn import IQNPolicy
|
||||
from tianshou.policy.modelfree.fqf import FQFPolicy
|
||||
from tianshou.policy.modelfree.pg import PGPolicy
|
||||
from tianshou.policy.modelfree.a2c import A2CPolicy
|
||||
from tianshou.policy.modelfree.npg import NPGPolicy
|
||||
@ -28,6 +29,7 @@ __all__ = [
|
||||
"C51Policy",
|
||||
"QRDQNPolicy",
|
||||
"IQNPolicy",
|
||||
"FQFPolicy",
|
||||
"PGPolicy",
|
||||
"A2CPolicy",
|
||||
"NPGPolicy",
|
||||
|
161
tianshou/policy/modelfree/fqf.py
Normal file
@ -0,0 +1,161 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from tianshou.policy import DQNPolicy, QRDQNPolicy
|
||||
from tianshou.data import Batch, to_numpy, ReplayBuffer
|
||||
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||
|
||||
|
||||
class FQFPolicy(QRDQNPolicy):
|
||||
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
|
||||
|
||||
:param torch.nn.Module model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param FractionProposalNetwork fraction_model: a FractionProposalNetwork for
|
||||
proposing fractions/quantiles given state.
|
||||
:param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing
|
||||
the fraction model above.
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param int num_fractions: the number of fractions to use. Default to 32.
|
||||
:param float ent_coef: the coefficient for entropy loss. Default to 0.
|
||||
:param int estimation_step: the number of steps to look ahead. Default to 1.
|
||||
:param int target_update_freq: the target network update frequency (0 if
|
||||
you do not use the target network).
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: FullQuantileFunction,
|
||||
optim: torch.optim.Optimizer,
|
||||
fraction_model: FractionProposalNetwork,
|
||||
fraction_optim: torch.optim.Optimizer,
|
||||
discount_factor: float = 0.99,
|
||||
num_fractions: int = 32,
|
||||
ent_coef: float = 0.0,
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: int = 0,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model, optim, discount_factor, num_fractions, estimation_step,
|
||||
target_update_freq, reward_normalization, **kwargs
|
||||
)
|
||||
self.propose_model = fraction_model
|
||||
self._ent_coef = ent_coef
|
||||
self._fraction_optim = fraction_optim
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs_next: s_{t+n}
|
||||
if self._target:
|
||||
result = self(batch, input="obs_next")
|
||||
a, fractions = result.act, result.fractions
|
||||
next_dist = self(
|
||||
batch, model="model_old", input="obs_next", fractions=fractions
|
||||
).logits
|
||||
else:
|
||||
next_b = self(batch, input="obs_next")
|
||||
a = next_b.act
|
||||
next_dist = next_b.logits
|
||||
next_dist = next_dist[np.arange(len(a)), a, :]
|
||||
return next_dist # shape: [bsz, num_quantiles]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: str = "model",
|
||||
input: str = "obs",
|
||||
fractions: Optional[Batch] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
model = getattr(self, model)
|
||||
obs = batch[input]
|
||||
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||
if fractions is None:
|
||||
(logits, fractions, quantiles_tau), h = model(
|
||||
obs_, propose_model=self.propose_model, state=state, info=batch.info
|
||||
)
|
||||
else:
|
||||
(logits, _, quantiles_tau), h = model(
|
||||
obs_, propose_model=self.propose_model, fractions=fractions,
|
||||
state=state, info=batch.info
|
||||
)
|
||||
weighted_logits = (
|
||||
fractions.taus[:, 1:] - fractions.taus[:, :-1]
|
||||
).unsqueeze(1) * logits
|
||||
q = DQNPolicy.compute_q_value(
|
||||
self, weighted_logits.sum(2), getattr(obs, "mask", None)
|
||||
)
|
||||
if not hasattr(self, "max_action_num"):
|
||||
self.max_action_num = q.shape[1]
|
||||
act = to_numpy(q.max(dim=1)[1])
|
||||
return Batch(
|
||||
logits=logits, act=act, state=h, fractions=fractions,
|
||||
quantiles_tau=quantiles_tau
|
||||
)
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
if self._target and self._iter % self._freq == 0:
|
||||
self.sync_weight()
|
||||
weight = batch.pop("weight", 1.0)
|
||||
out = self(batch)
|
||||
curr_dist_orig = out.logits
|
||||
taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
|
||||
act = batch.act
|
||||
curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
|
||||
target_dist = batch.returns.unsqueeze(1)
|
||||
# calculate each element's difference between curr_dist and target_dist
|
||||
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
|
||||
huber_loss = (u * (
|
||||
tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
|
||||
).abs()).sum(-1).mean(1)
|
||||
quantile_loss = (huber_loss * weight).mean()
|
||||
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
|
||||
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
|
||||
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
|
||||
# calculate fraction loss
|
||||
with torch.no_grad():
|
||||
sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
|
||||
sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
|
||||
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
|
||||
# blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
|
||||
values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
|
||||
signs_1 = sa_quantiles > torch.cat([
|
||||
sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
|
||||
|
||||
values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
|
||||
signs_2 = sa_quantiles < torch.cat([
|
||||
sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
|
||||
|
||||
gradient_of_taus = (
|
||||
torch.where(signs_1, values_1, -values_1)
|
||||
+ torch.where(signs_2, values_2, -values_2)
|
||||
)
|
||||
fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
|
||||
# calculate entropy loss
|
||||
entropy_loss = out.fractions.entropies.mean()
|
||||
fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
|
||||
self._fraction_optim.zero_grad()
|
||||
fraction_entropy_loss.backward(retain_graph=True)
|
||||
self._fraction_optim.step()
|
||||
self.optim.zero_grad()
|
||||
quantile_loss.backward()
|
||||
self.optim.step()
|
||||
self._iter += 1
|
||||
return {
|
||||
"loss": quantile_loss.item() + fraction_entropy_loss.item(),
|
||||
"loss/quantile": quantile_loss.item(),
|
||||
"loss/fraction": fraction_loss.item(),
|
||||
"loss/entropy": entropy_loss.item()
|
||||
}
|
@ -4,6 +4,7 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.utils.net.common import MLP
|
||||
|
||||
|
||||
@ -199,6 +200,110 @@ class ImplicitQuantileNetwork(Critic):
|
||||
embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
|
||||
batch_size * sample_size, -1
|
||||
)
|
||||
out = self.last(embedding).view(batch_size,
|
||||
sample_size, -1).transpose(1, 2)
|
||||
out = self.last(embedding).view(
|
||||
batch_size, sample_size, -1).transpose(1, 2)
|
||||
return (out, taus), h
|
||||
|
||||
|
||||
class FractionProposalNetwork(nn.Module):
|
||||
"""Fraction proposal network for FQF.
|
||||
|
||||
:param num_fractions: the number of factions to propose.
|
||||
:param embedding_dim: the dimension of the embedding/input.
|
||||
|
||||
.. note::
|
||||
|
||||
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
|
||||
/fqf_iqn_qrdqn/network.py .
|
||||
"""
|
||||
|
||||
def __init__(self, num_fractions: int, embedding_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.net = nn.Linear(embedding_dim, num_fractions)
|
||||
torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)
|
||||
torch.nn.init.constant_(self.net.bias, 0)
|
||||
self.num_fractions = num_fractions
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(
|
||||
self, state_embeddings: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Calculate (log of) probabilities q_i in the paper.
|
||||
m = torch.distributions.Categorical(logits=self.net(state_embeddings))
|
||||
taus_1_N = torch.cumsum(m.probs, dim=1)
|
||||
# Calculate \tau_i (i=0,...,N).
|
||||
taus = F.pad(taus_1_N, (1, 0))
|
||||
# Calculate \hat \tau_i (i=0,...,N-1).
|
||||
tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
|
||||
# Calculate entropies of value distributions.
|
||||
entropies = m.entropy()
|
||||
return taus, tau_hats, entropies
|
||||
|
||||
|
||||
class FullQuantileFunction(ImplicitQuantileNetwork):
|
||||
"""Full(y parameterized) Quantile Function.
|
||||
|
||||
:param preprocess_net: a self-defined preprocess_net which output a
|
||||
flattened hidden state.
|
||||
:param int action_dim: the dimension of action space.
|
||||
:param hidden_sizes: a sequence of int for constructing the MLP after
|
||||
preprocess_net. Default to empty sequence (where the MLP now contains
|
||||
only a single linear layer).
|
||||
:param int num_cosines: the number of cosines to use for cosine embedding.
|
||||
Default to 64.
|
||||
:param int preprocess_net_output_dim: the output dimension of
|
||||
preprocess_net.
|
||||
|
||||
.. note::
|
||||
|
||||
The first return value is a tuple of (quantiles, fractions, quantiles_tau),
|
||||
where fractions is a Batch(taus, tau_hats, entropies).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
action_shape: Sequence[int],
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
num_cosines: int = 64,
|
||||
preprocess_net_output_dim: Optional[int] = None,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
preprocess_net, action_shape, hidden_sizes,
|
||||
num_cosines, preprocess_net_output_dim, device
|
||||
)
|
||||
|
||||
def _compute_quantiles(
|
||||
self, obs: torch.Tensor, taus: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
batch_size, sample_size = taus.shape
|
||||
embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(
|
||||
batch_size * sample_size, -1
|
||||
)
|
||||
quantiles = self.last(embedding).view(
|
||||
batch_size, sample_size, -1
|
||||
).transpose(1, 2)
|
||||
return quantiles
|
||||
|
||||
def forward( # type: ignore
|
||||
self, s: Union[np.ndarray, torch.Tensor],
|
||||
propose_model: FractionProposalNetwork,
|
||||
fractions: Optional[Batch] = None,
|
||||
**kwargs: Any
|
||||
) -> Tuple[Any, torch.Tensor]:
|
||||
r"""Mapping: s -> Q(s, \*)."""
|
||||
logits, h = self.preprocess(s, state=kwargs.get("state", None))
|
||||
# Propose fractions
|
||||
if fractions is None:
|
||||
taus, tau_hats, entropies = propose_model(logits.detach())
|
||||
fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)
|
||||
else:
|
||||
taus, tau_hats = fractions.taus, fractions.tau_hats
|
||||
quantiles = self._compute_quantiles(logits, tau_hats)
|
||||
# Calculate quantiles_tau for computing fraction grad
|
||||
quantiles_tau = None
|
||||
if self.training:
|
||||
with torch.no_grad():
|
||||
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
|
||||
return (quantiles, fractions, quantiles_tau), h
|
||||
|