Add Fully-parameterized Quantile Function (#376)
@ -25,6 +25,7 @@
|
|||||||
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
||||||
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
||||||
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
|
||||||
|
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
|
||||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||||
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
|
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
|
||||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
||||||
|
|||||||
@ -40,6 +40,11 @@ DQN Family
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: tianshou.policy.FQFPolicy
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
On-policy
|
On-policy
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@ -15,6 +15,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
||||||
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
||||||
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
|
||||||
|
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
|
||||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||||
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
|
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
|
||||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||||
|
|||||||
@ -68,6 +68,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
|||||||
| SeaquestNoFrameskip-v4 | 4874 |  | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` |
|
| SeaquestNoFrameskip-v4 | 4874 |  | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` |
|
||||||
| SpaceInvadersNoFrameskip-v4 | 1498.5 |  | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` |
|
| SpaceInvadersNoFrameskip-v4 | 1498.5 |  | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||||
|
|
||||||
|
# FQF (single run)
|
||||||
|
|
||||||
|
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||||
|
|
||||||
|
| task | best reward | reward curve | parameters |
|
||||||
|
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
|
||||||
|
| PongNoFrameskip-v4 | 20.7 |  | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` |
|
||||||
|
| BreakoutNoFrameskip-v4 | 517.3 |  | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
|
||||||
|
| EnduroNoFrameskip-v4 | 2240.5 |  | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` |
|
||||||
|
| QbertNoFrameskip-v4 | 16172.5 |  | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` |
|
||||||
|
| MsPacmanNoFrameskip-v4 | 2429 |  | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` |
|
||||||
|
| SeaquestNoFrameskip-v4 | 10775 |  | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
|
||||||
|
| SpaceInvadersNoFrameskip-v4 | 2482 |  | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |
|
||||||
|
|
||||||
# BCQ
|
# BCQ
|
||||||
|
|
||||||
To running BCQ algorithm on Atari, you need to do the following things:
|
To running BCQ algorithm on Atari, you need to do the following things:
|
||||||
|
|||||||
186
examples/atari/atari_fqf.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import pprint
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.policy import FQFPolicy
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
|
from tianshou.env import SubprocVectorEnv
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
|
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||||
|
|
||||||
|
from atari_network import DQN
|
||||||
|
from atari_wrapper import wrap_deepmind
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||||
|
parser.add_argument('--seed', type=int, default=3128)
|
||||||
|
parser.add_argument('--eps-test', type=float, default=0.005)
|
||||||
|
parser.add_argument('--eps-train', type=float, default=1.)
|
||||||
|
parser.add_argument('--eps-train-final', type=float, default=0.05)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||||
|
parser.add_argument('--lr', type=float, default=5e-5)
|
||||||
|
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.99)
|
||||||
|
parser.add_argument('--num-fractions', type=int, default=32)
|
||||||
|
parser.add_argument('--num-cosines', type=int, default=64)
|
||||||
|
parser.add_argument('--ent-coef', type=float, default=10.)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||||
|
parser.add_argument('--epoch', type=int, default=100)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=100000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=32)
|
||||||
|
parser.add_argument('--training-num', type=int, default=10)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str,
|
||||||
|
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
parser.add_argument('--frames-stack', type=int, default=4)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument('--watch', default=False, action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only')
|
||||||
|
parser.add_argument('--save-buffer-name', type=str, default=None)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def make_atari_env(args):
|
||||||
|
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||||
|
|
||||||
|
|
||||||
|
def make_atari_env_watch(args):
|
||||||
|
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||||
|
episode_life=False, clip_rewards=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fqf(args=get_args()):
|
||||||
|
env = make_atari_env(args)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
# should be N_FRAMES x H x W
|
||||||
|
print("Observations shape:", args.state_shape)
|
||||||
|
print("Actions shape:", args.action_shape)
|
||||||
|
# make environments
|
||||||
|
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
|
||||||
|
for _ in range(args.training_num)])
|
||||||
|
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||||
|
for _ in range(args.test_num)])
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# define model
|
||||||
|
feature_net = DQN(*args.state_shape, args.action_shape, args.device,
|
||||||
|
features_only=True)
|
||||||
|
net = FullQuantileFunction(
|
||||||
|
feature_net, args.action_shape, args.hidden_sizes,
|
||||||
|
args.num_cosines, device=args.device
|
||||||
|
).to(args.device)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
|
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
|
||||||
|
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(),
|
||||||
|
lr=args.fraction_lr)
|
||||||
|
# define policy
|
||||||
|
policy = FQFPolicy(
|
||||||
|
net, optim, fraction_net, fraction_optim,
|
||||||
|
args.gamma, args.num_fractions, args.ent_coef, args.n_step,
|
||||||
|
target_update_freq=args.target_update_freq
|
||||||
|
).to(args.device)
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||||
|
# when you have enough RAM
|
||||||
|
buffer = VectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(train_envs),
|
||||||
|
ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
|
||||||
|
# collector
|
||||||
|
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = BasicLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
if env.spec.reward_threshold:
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
elif 'Pong' in args.task:
|
||||||
|
return mean_rewards >= 20
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def train_fn(epoch, env_step):
|
||||||
|
# nature DQN setting, linear decay in the first 1M steps
|
||||||
|
if env_step <= 1e6:
|
||||||
|
eps = args.eps_train - env_step / 1e6 * \
|
||||||
|
(args.eps_train - args.eps_train_final)
|
||||||
|
else:
|
||||||
|
eps = args.eps_train_final
|
||||||
|
policy.set_eps(eps)
|
||||||
|
logger.write('train/eps', env_step, eps)
|
||||||
|
|
||||||
|
def test_fn(epoch, env_step):
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
|
||||||
|
# watch agent's performance
|
||||||
|
def watch():
|
||||||
|
print("Setup test envs ...")
|
||||||
|
policy.eval()
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
if args.save_buffer_name:
|
||||||
|
print(f"Generate buffer with size {args.buffer_size}")
|
||||||
|
buffer = VectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(test_envs),
|
||||||
|
ignore_obs_next=True, save_only_last_obs=True,
|
||||||
|
stack_num=args.frames_stack)
|
||||||
|
collector = Collector(policy, test_envs, buffer,
|
||||||
|
exploration_noise=True)
|
||||||
|
result = collector.collect(n_step=args.buffer_size)
|
||||||
|
print(f"Save buffer into {args.save_buffer_name}")
|
||||||
|
# Unfortunately, pickle will cause oom with 1M buffer size
|
||||||
|
buffer.save_hdf5(args.save_buffer_name)
|
||||||
|
else:
|
||||||
|
print("Testing agent ...")
|
||||||
|
test_collector.reset()
|
||||||
|
result = test_collector.collect(n_episode=args.test_num,
|
||||||
|
render=args.render)
|
||||||
|
rew = result["rews"].mean()
|
||||||
|
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||||
|
|
||||||
|
if args.watch:
|
||||||
|
watch()
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
# test train_collector and start filling replay buffer
|
||||||
|
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy, train_collector, test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
|
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||||
|
update_per_step=args.update_per_step, test_in_train=False)
|
||||||
|
|
||||||
|
pprint.pprint(result)
|
||||||
|
watch()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_fqf(get_args())
|
||||||
BIN
examples/atari/results/fqf/Breakout_rew.png
Normal file
|
After Width: | Height: | Size: 215 KiB |
BIN
examples/atari/results/fqf/Enduro_rew.png
Normal file
|
After Width: | Height: | Size: 188 KiB |
BIN
examples/atari/results/fqf/MsPacman_rew.png
Normal file
|
After Width: | Height: | Size: 200 KiB |
BIN
examples/atari/results/fqf/Pong_rew.png
Normal file
|
After Width: | Height: | Size: 140 KiB |
BIN
examples/atari/results/fqf/Qbert_rew.png
Normal file
|
After Width: | Height: | Size: 194 KiB |
BIN
examples/atari/results/fqf/Seaquest_rew.png
Normal file
|
After Width: | Height: | Size: 201 KiB |
BIN
examples/atari/results/fqf/SpaceInvaders_rew.png
Normal file
|
After Width: | Height: | Size: 215 KiB |
153
test/discrete/test_fqf.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
import os
|
||||||
|
import gym
|
||||||
|
import torch
|
||||||
|
import pprint
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.policy import FQFPolicy
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
|
from tianshou.env import DummyVectorEnv
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||||
|
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
parser.add_argument('--lr', type=float, default=3e-3)
|
||||||
|
parser.add_argument('--fraction-lr', type=float, default=2.5e-9)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.9)
|
||||||
|
parser.add_argument('--num-fractions', type=int, default=32)
|
||||||
|
parser.add_argument('--num-cosines', type=int, default=64)
|
||||||
|
parser.add_argument('--ent-coef', type=float, default=10.)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=10000)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int,
|
||||||
|
nargs='*', default=[64, 64, 64])
|
||||||
|
parser.add_argument('--training-num', type=int, default=10)
|
||||||
|
parser.add_argument('--test-num', type=int, default=100)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument('--prioritized-replay',
|
||||||
|
action="store_true", default=False)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.6)
|
||||||
|
parser.add_argument('--beta', type=float, default=0.4)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str,
|
||||||
|
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_fqf(args=get_args()):
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
# you can also use tianshou.env.SubprocVectorEnv
|
||||||
|
train_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
feature_net = Net(args.state_shape, args.hidden_sizes[-1],
|
||||||
|
hidden_sizes=args.hidden_sizes[:-1], device=args.device,
|
||||||
|
softmax=False)
|
||||||
|
net = FullQuantileFunction(
|
||||||
|
feature_net, args.action_shape, args.hidden_sizes,
|
||||||
|
num_cosines=args.num_cosines, device=args.device
|
||||||
|
)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
|
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
|
||||||
|
fraction_optim = torch.optim.RMSprop(
|
||||||
|
fraction_net.parameters(), lr=args.fraction_lr
|
||||||
|
)
|
||||||
|
policy = FQFPolicy(
|
||||||
|
net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions,
|
||||||
|
args.ent_coef, args.n_step, target_update_freq=args.target_update_freq
|
||||||
|
).to(args.device)
|
||||||
|
# buffer
|
||||||
|
if args.prioritized_replay:
|
||||||
|
buf = PrioritizedVectorReplayBuffer(
|
||||||
|
args.buffer_size, buffer_num=len(train_envs),
|
||||||
|
alpha=args.alpha, beta=args.beta)
|
||||||
|
else:
|
||||||
|
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||||
|
# collector
|
||||||
|
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||||
|
# policy.set_eps(1)
|
||||||
|
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'fqf')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
logger = BasicLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
|
||||||
|
def train_fn(epoch, env_step):
|
||||||
|
# eps annnealing, just a demo
|
||||||
|
if env_step <= 10000:
|
||||||
|
policy.set_eps(args.eps_train)
|
||||||
|
elif env_step <= 50000:
|
||||||
|
eps = args.eps_train - (env_step - 10000) / \
|
||||||
|
40000 * (0.9 * args.eps_train)
|
||||||
|
policy.set_eps(eps)
|
||||||
|
else:
|
||||||
|
policy.set_eps(0.1 * args.eps_train)
|
||||||
|
|
||||||
|
def test_fn(epoch, env_step):
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy, train_collector, test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||||
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
|
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||||
|
update_per_step=args.update_per_step)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
policy.eval()
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_pfqf(args=get_args()):
|
||||||
|
args.prioritized_replay = True
|
||||||
|
args.gamma = .95
|
||||||
|
test_fqf(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_fqf(get_args())
|
||||||
@ -4,6 +4,7 @@ from tianshou.policy.modelfree.dqn import DQNPolicy
|
|||||||
from tianshou.policy.modelfree.c51 import C51Policy
|
from tianshou.policy.modelfree.c51 import C51Policy
|
||||||
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
|
||||||
from tianshou.policy.modelfree.iqn import IQNPolicy
|
from tianshou.policy.modelfree.iqn import IQNPolicy
|
||||||
|
from tianshou.policy.modelfree.fqf import FQFPolicy
|
||||||
from tianshou.policy.modelfree.pg import PGPolicy
|
from tianshou.policy.modelfree.pg import PGPolicy
|
||||||
from tianshou.policy.modelfree.a2c import A2CPolicy
|
from tianshou.policy.modelfree.a2c import A2CPolicy
|
||||||
from tianshou.policy.modelfree.npg import NPGPolicy
|
from tianshou.policy.modelfree.npg import NPGPolicy
|
||||||
@ -28,6 +29,7 @@ __all__ = [
|
|||||||
"C51Policy",
|
"C51Policy",
|
||||||
"QRDQNPolicy",
|
"QRDQNPolicy",
|
||||||
"IQNPolicy",
|
"IQNPolicy",
|
||||||
|
"FQFPolicy",
|
||||||
"PGPolicy",
|
"PGPolicy",
|
||||||
"A2CPolicy",
|
"A2CPolicy",
|
||||||
"NPGPolicy",
|
"NPGPolicy",
|
||||||
|
|||||||
161
tianshou/policy/modelfree/fqf.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
from tianshou.policy import DQNPolicy, QRDQNPolicy
|
||||||
|
from tianshou.data import Batch, to_numpy, ReplayBuffer
|
||||||
|
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
|
||||||
|
|
||||||
|
|
||||||
|
class FQFPolicy(QRDQNPolicy):
|
||||||
|
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: a model following the rules in
|
||||||
|
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||||
|
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||||
|
:param FractionProposalNetwork fraction_model: a FractionProposalNetwork for
|
||||||
|
proposing fractions/quantiles given state.
|
||||||
|
:param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing
|
||||||
|
the fraction model above.
|
||||||
|
:param float discount_factor: in [0, 1].
|
||||||
|
:param int num_fractions: the number of fractions to use. Default to 32.
|
||||||
|
:param float ent_coef: the coefficient for entropy loss. Default to 0.
|
||||||
|
:param int estimation_step: the number of steps to look ahead. Default to 1.
|
||||||
|
:param int target_update_freq: the target network update frequency (0 if
|
||||||
|
you do not use the target network).
|
||||||
|
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||||
|
Default to False.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: FullQuantileFunction,
|
||||||
|
optim: torch.optim.Optimizer,
|
||||||
|
fraction_model: FractionProposalNetwork,
|
||||||
|
fraction_optim: torch.optim.Optimizer,
|
||||||
|
discount_factor: float = 0.99,
|
||||||
|
num_fractions: int = 32,
|
||||||
|
ent_coef: float = 0.0,
|
||||||
|
estimation_step: int = 1,
|
||||||
|
target_update_freq: int = 0,
|
||||||
|
reward_normalization: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
model, optim, discount_factor, num_fractions, estimation_step,
|
||||||
|
target_update_freq, reward_normalization, **kwargs
|
||||||
|
)
|
||||||
|
self.propose_model = fraction_model
|
||||||
|
self._ent_coef = ent_coef
|
||||||
|
self._fraction_optim = fraction_optim
|
||||||
|
|
||||||
|
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
|
||||||
|
batch = buffer[indice] # batch.obs_next: s_{t+n}
|
||||||
|
if self._target:
|
||||||
|
result = self(batch, input="obs_next")
|
||||||
|
a, fractions = result.act, result.fractions
|
||||||
|
next_dist = self(
|
||||||
|
batch, model="model_old", input="obs_next", fractions=fractions
|
||||||
|
).logits
|
||||||
|
else:
|
||||||
|
next_b = self(batch, input="obs_next")
|
||||||
|
a = next_b.act
|
||||||
|
next_dist = next_b.logits
|
||||||
|
next_dist = next_dist[np.arange(len(a)), a, :]
|
||||||
|
return next_dist # shape: [bsz, num_quantiles]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: Batch,
|
||||||
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||||
|
model: str = "model",
|
||||||
|
input: str = "obs",
|
||||||
|
fractions: Optional[Batch] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Batch:
|
||||||
|
model = getattr(self, model)
|
||||||
|
obs = batch[input]
|
||||||
|
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||||
|
if fractions is None:
|
||||||
|
(logits, fractions, quantiles_tau), h = model(
|
||||||
|
obs_, propose_model=self.propose_model, state=state, info=batch.info
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(logits, _, quantiles_tau), h = model(
|
||||||
|
obs_, propose_model=self.propose_model, fractions=fractions,
|
||||||
|
state=state, info=batch.info
|
||||||
|
)
|
||||||
|
weighted_logits = (
|
||||||
|
fractions.taus[:, 1:] - fractions.taus[:, :-1]
|
||||||
|
).unsqueeze(1) * logits
|
||||||
|
q = DQNPolicy.compute_q_value(
|
||||||
|
self, weighted_logits.sum(2), getattr(obs, "mask", None)
|
||||||
|
)
|
||||||
|
if not hasattr(self, "max_action_num"):
|
||||||
|
self.max_action_num = q.shape[1]
|
||||||
|
act = to_numpy(q.max(dim=1)[1])
|
||||||
|
return Batch(
|
||||||
|
logits=logits, act=act, state=h, fractions=fractions,
|
||||||
|
quantiles_tau=quantiles_tau
|
||||||
|
)
|
||||||
|
|
||||||
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
|
if self._target and self._iter % self._freq == 0:
|
||||||
|
self.sync_weight()
|
||||||
|
weight = batch.pop("weight", 1.0)
|
||||||
|
out = self(batch)
|
||||||
|
curr_dist_orig = out.logits
|
||||||
|
taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
|
||||||
|
act = batch.act
|
||||||
|
curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
|
||||||
|
target_dist = batch.returns.unsqueeze(1)
|
||||||
|
# calculate each element's difference between curr_dist and target_dist
|
||||||
|
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
|
||||||
|
huber_loss = (u * (
|
||||||
|
tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
|
||||||
|
).abs()).sum(-1).mean(1)
|
||||||
|
quantile_loss = (huber_loss * weight).mean()
|
||||||
|
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
|
||||||
|
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
|
||||||
|
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
|
||||||
|
# calculate fraction loss
|
||||||
|
with torch.no_grad():
|
||||||
|
sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
|
||||||
|
sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
|
||||||
|
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
|
||||||
|
# blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
|
||||||
|
values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
|
||||||
|
signs_1 = sa_quantiles > torch.cat([
|
||||||
|
sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
|
||||||
|
|
||||||
|
values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
|
||||||
|
signs_2 = sa_quantiles < torch.cat([
|
||||||
|
sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
|
||||||
|
|
||||||
|
gradient_of_taus = (
|
||||||
|
torch.where(signs_1, values_1, -values_1)
|
||||||
|
+ torch.where(signs_2, values_2, -values_2)
|
||||||
|
)
|
||||||
|
fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
|
||||||
|
# calculate entropy loss
|
||||||
|
entropy_loss = out.fractions.entropies.mean()
|
||||||
|
fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
|
||||||
|
self._fraction_optim.zero_grad()
|
||||||
|
fraction_entropy_loss.backward(retain_graph=True)
|
||||||
|
self._fraction_optim.step()
|
||||||
|
self.optim.zero_grad()
|
||||||
|
quantile_loss.backward()
|
||||||
|
self.optim.step()
|
||||||
|
self._iter += 1
|
||||||
|
return {
|
||||||
|
"loss": quantile_loss.item() + fraction_entropy_loss.item(),
|
||||||
|
"loss/quantile": quantile_loss.item(),
|
||||||
|
"loss/fraction": fraction_loss.item(),
|
||||||
|
"loss/entropy": entropy_loss.item()
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ from torch import nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||||
|
|
||||||
|
from tianshou.data import Batch
|
||||||
from tianshou.utils.net.common import MLP
|
from tianshou.utils.net.common import MLP
|
||||||
|
|
||||||
|
|
||||||
@ -199,6 +200,110 @@ class ImplicitQuantileNetwork(Critic):
|
|||||||
embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
|
embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
|
||||||
batch_size * sample_size, -1
|
batch_size * sample_size, -1
|
||||||
)
|
)
|
||||||
out = self.last(embedding).view(batch_size,
|
out = self.last(embedding).view(
|
||||||
sample_size, -1).transpose(1, 2)
|
batch_size, sample_size, -1).transpose(1, 2)
|
||||||
return (out, taus), h
|
return (out, taus), h
|
||||||
|
|
||||||
|
|
||||||
|
class FractionProposalNetwork(nn.Module):
|
||||||
|
"""Fraction proposal network for FQF.
|
||||||
|
|
||||||
|
:param num_fractions: the number of factions to propose.
|
||||||
|
:param embedding_dim: the dimension of the embedding/input.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master
|
||||||
|
/fqf_iqn_qrdqn/network.py .
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_fractions: int, embedding_dim: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Linear(embedding_dim, num_fractions)
|
||||||
|
torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01)
|
||||||
|
torch.nn.init.constant_(self.net.bias, 0)
|
||||||
|
self.num_fractions = num_fractions
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, state_embeddings: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# Calculate (log of) probabilities q_i in the paper.
|
||||||
|
m = torch.distributions.Categorical(logits=self.net(state_embeddings))
|
||||||
|
taus_1_N = torch.cumsum(m.probs, dim=1)
|
||||||
|
# Calculate \tau_i (i=0,...,N).
|
||||||
|
taus = F.pad(taus_1_N, (1, 0))
|
||||||
|
# Calculate \hat \tau_i (i=0,...,N-1).
|
||||||
|
tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0
|
||||||
|
# Calculate entropies of value distributions.
|
||||||
|
entropies = m.entropy()
|
||||||
|
return taus, tau_hats, entropies
|
||||||
|
|
||||||
|
|
||||||
|
class FullQuantileFunction(ImplicitQuantileNetwork):
|
||||||
|
"""Full(y parameterized) Quantile Function.
|
||||||
|
|
||||||
|
:param preprocess_net: a self-defined preprocess_net which output a
|
||||||
|
flattened hidden state.
|
||||||
|
:param int action_dim: the dimension of action space.
|
||||||
|
:param hidden_sizes: a sequence of int for constructing the MLP after
|
||||||
|
preprocess_net. Default to empty sequence (where the MLP now contains
|
||||||
|
only a single linear layer).
|
||||||
|
:param int num_cosines: the number of cosines to use for cosine embedding.
|
||||||
|
Default to 64.
|
||||||
|
:param int preprocess_net_output_dim: the output dimension of
|
||||||
|
preprocess_net.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The first return value is a tuple of (quantiles, fractions, quantiles_tau),
|
||||||
|
where fractions is a Batch(taus, tau_hats, entropies).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
preprocess_net: nn.Module,
|
||||||
|
action_shape: Sequence[int],
|
||||||
|
hidden_sizes: Sequence[int] = (),
|
||||||
|
num_cosines: int = 64,
|
||||||
|
preprocess_net_output_dim: Optional[int] = None,
|
||||||
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
preprocess_net, action_shape, hidden_sizes,
|
||||||
|
num_cosines, preprocess_net_output_dim, device
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_quantiles(
|
||||||
|
self, obs: torch.Tensor, taus: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, sample_size = taus.shape
|
||||||
|
embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(
|
||||||
|
batch_size * sample_size, -1
|
||||||
|
)
|
||||||
|
quantiles = self.last(embedding).view(
|
||||||
|
batch_size, sample_size, -1
|
||||||
|
).transpose(1, 2)
|
||||||
|
return quantiles
|
||||||
|
|
||||||
|
def forward( # type: ignore
|
||||||
|
self, s: Union[np.ndarray, torch.Tensor],
|
||||||
|
propose_model: FractionProposalNetwork,
|
||||||
|
fractions: Optional[Batch] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Tuple[Any, torch.Tensor]:
|
||||||
|
r"""Mapping: s -> Q(s, \*)."""
|
||||||
|
logits, h = self.preprocess(s, state=kwargs.get("state", None))
|
||||||
|
# Propose fractions
|
||||||
|
if fractions is None:
|
||||||
|
taus, tau_hats, entropies = propose_model(logits.detach())
|
||||||
|
fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies)
|
||||||
|
else:
|
||||||
|
taus, tau_hats = fractions.taus, fractions.tau_hats
|
||||||
|
quantiles = self._compute_quantiles(logits, tau_hats)
|
||||||
|
# Calculate quantiles_tau for computing fraction grad
|
||||||
|
quantiles_tau = None
|
||||||
|
if self.training:
|
||||||
|
with torch.no_grad():
|
||||||
|
quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1])
|
||||||
|
return (quantiles, fractions, quantiles_tau), h
|
||||||
|
|||||||