Implement Generative Adversarial Imitation Learning (GAIL) (#550)

Implement GAIL based on PPO and provide an example script and sample (i.e., most likely not the best) results on MuJoCo tasks. (#531, #173)
Yi Su 2022-03-06 07:57:15 -08:00 committed by GitHub
parent d976a5aa91
commit 2377f2f186
12 changed files with 683 additions and 0 deletions

@@ -41,6 +41,7 @@
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)

@@ -134,6 +134,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.GAILPolicy
:members:
:undoc-members:
:show-inheritance:
Model-based
-----------

@@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning <https://arxiv.org/pdf/1606.03476.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_

@@ -140,3 +140,6 @@ Strens
Ornstein
Uhlenbeck
mse
gail
airl
ppo

@@ -0,0 +1,27 @@
# Inverse Reinforcement Learning
In the inverse reinforcement learning setting, the agent learns a policy from interaction with an environment that provides no reward, together with a fixed dataset collected by an expert policy.
## Continuous control
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
We provide an implementation of the GAIL algorithm for continuous control.
### Train
You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `expert_buffer` parameter of `GAILPolicy`. `irl_gail.py` is an example of inverse RL using d4rl datasets.
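A condensed sketch of this loading step (adapted from `irl_gail.py`; it assumes `d4rl` is installed and uses `halfcheetah-expert-v2` as an example dataset name):

```python
import d4rl  # noqa: F401 -- registers the offline dataset environments in gym
import gym

from tianshou.data import Batch, ReplayBuffer

# d4rl returns a dict of aligned numpy arrays
dataset = d4rl.qlearning_dataset(gym.make("halfcheetah-expert-v2"))
expert_buffer = ReplayBuffer(dataset["rewards"].size)
for i in range(dataset["rewards"].size):
    expert_buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=dataset["rewards"][i],
            done=dataset["terminals"][i],
            obs_next=dataset["next_observations"][i],
        )
    )
# expert_buffer is then passed as the expert_buffer argument of GAILPolicy
```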
To train an agent with the GAIL algorithm:
```bash
python irl_gail.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
```
## GAIL (single run)
| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| HalfCheetah-v2 | 5177.07 | ![](results/gail/HalfCheetah-v2_rew.png) | `python3 irl_gail.py --task "HalfCheetah-v2" --expert-data-task "halfcheetah-expert-v2"` |
| Hopper-v2 | 1761.44 | ![](results/gail/Hopper-v2_rew.png) | `python3 irl_gail.py --task "Hopper-v2" --expert-data-task "hopper-expert-v2"` |
| Walker2d-v2 | 2020.77 | ![](results/gail/Walker2d-v2_rew.png) | `python3 irl_gail.py --task "Walker2d-v2" --expert-data-task "walker2d-expert-v2"` |

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import d4rl
import gym
import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import GAILPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
class NoRewardEnv(gym.RewardWrapper):
"""sets the reward to 0.
:param gym.Env env: the environment to wrap.
"""
def __init__(self, env):
super().__init__(env)
def reward(self, reward):
"""Set reward to 0."""
return np.zeros_like(reward)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v2')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument(
'--expert-data-task', type=str, default='halfcheetah-expert-v2'
)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--disc-lr', type=float, default=2.5e-5)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--disc-update-num', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)
# ppo special
parser.add_argument('--rew-norm', type=int, default=True)
# In theory, `vf-coef` will not make any difference if using Adam optimizer.
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=0)
parser.add_argument('--recompute-adv', type=int, default=1)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
)
return parser.parse_args()
def test_gail(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)],
norm_obs=True
)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True,
obs_rms=train_envs.obs_rms,
update_obs_rms=False
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
activation=nn.Tanh,
device=args.device
)
actor = ActorProb(
net_a,
args.action_shape,
max_action=args.max_action,
unbounded=True,
device=args.device
).to(args.device)
net_c = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
activation=nn.Tanh,
device=args.device
)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
# discriminator
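    # the discriminator reuses Net(concat=True) + Critic, i.e. it maps the
    # concatenated (obs, act) vector to a single logit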
net_d = Net(
args.state_shape,
action_shape=args.action_shape,
hidden_sizes=args.hidden_sizes,
activation=nn.Tanh,
device=args.device,
concat=True
)
disc_net = Critic(net_d, device=args.device).to(args.device)
for m in disc_net.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)
lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect
) * args.epoch
lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
)
def dist(*logits):
return Independent(Normal(*logits), 1)
# expert replay buffer
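    # d4rl.qlearning_dataset returns a dict of aligned numpy arrays with keys
    # observations, actions, next_observations, rewards, terminals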
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
dataset_size = dataset['rewards'].size
print("dataset_size", dataset_size)
expert_buffer = ReplayBuffer(dataset_size)
for i in range(dataset_size):
expert_buffer.add(
Batch(
obs=dataset['observations'][i],
act=dataset['actions'][i],
rew=dataset['rewards'][i],
done=dataset['terminals'][i],
obs_next=dataset['next_observations'][i],
)
)
print("dataset loaded")
policy = GAILPolicy(
actor,
critic,
optim,
dist,
expert_buffer,
disc_net,
disc_optim,
disc_update_num=args.disc_update_num,
discount_factor=args.gamma,
gae_lambda=args.gae_lambda,
max_grad_norm=args.max_grad_norm,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
reward_normalization=args.rew_norm,
action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler,
action_space=env.action_space,
eps_clip=args.eps_clip,
value_clip=args.value_clip,
dual_clip=args.dual_clip,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail'
log_path = os.path.join(args.logdir, args.task, 'gail', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
if not args.watch:
# trainer
result = onpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.repeat_per_collect,
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
save_fn=save_fn,
logger=logger,
test_in_train=False
)
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
test_gail()

Binary files not shown: 3 reward-curve images added under results/gail/ (199 KiB, 233 KiB, 205 KiB).

test/offline/test_gail.py (new file, 228 lines)

@@ -0,0 +1,228 @@
import argparse
import os
import pickle
import pprint
import gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import GAILPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v1')
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--disc-lr', type=float, default=5e-4)
parser.add_argument('--gamma', type=float, default=0.95)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=150000)
parser.add_argument('--episode-per-collect', type=int, default=16)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--disc-update-num', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
# ppo special
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
args = parser.parse_known_args()[0]
return args
def test_gail(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
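    # whichever branch ran, `buffer` now holds (near-)expert Pendulum transitions
    # and is passed to GAILPolicy below as its expert_buffer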
env = gym.make(args.task)
if args.reward_threshold is None:
default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
args.reward_threshold = default_reward_threshold.get(
args.task, env.spec.reward_threshold
)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
net, args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
critic = Critic(
Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
device=args.device
).to(args.device)
actor_critic = ActorCritic(actor, critic)
# orthogonal initialization
for m in actor_critic.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# discriminator
disc_net = Critic(
Net(
args.state_shape,
action_shape=args.action_shape,
hidden_sizes=args.hidden_sizes,
activation=torch.nn.Tanh,
device=args.device,
concat=True,
),
device=args.device
).to(args.device)
for m in disc_net.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)
    # replace DiagGaussian with Independent(Normal), which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = GAILPolicy(
actor,
critic,
optim,
dist,
buffer,
disc_net,
disc_optim,
disc_update_num=args.disc_update_num,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv,
dual_clip=args.dual_clip,
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
action_space=env.action_space,
)
# collector
train_collector = Collector(
policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'gail')
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save(
{
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth')
)
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
# trainer
result = onpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.repeat_per_collect,
args.test_num,
args.batch_size,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
test_gail()

@@ -24,6 +24,7 @@ from tianshou.policy.imitation.cql import CQLPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
from tianshou.policy.imitation.gail import GAILPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@@ -52,6 +53,7 @@ __all__ = [
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"DiscreteCRRPolicy",
"GAILPolicy",
"PSRLPolicy",
"ICMPolicy",
"MultiAgentPolicyManager",

@@ -0,0 +1,139 @@
from typing import Any, Dict, List, Optional, Type
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.policy import PPOPolicy
class GAILPolicy(PPOPolicy):
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
:type dist_fn: Type[torch.distributions.Distribution]
    :param ReplayBuffer expert_buffer: the replay buffer containing expert experience.
    :param torch.nn.Module disc_net: the discriminator network, whose input dim
        equals state dim plus action dim and whose output dim equals 1.
:param torch.optim.Optimizer disc_optim: the optimizer for the discriminator
network.
:param int disc_update_num: the number of discriminator grad steps per model grad
step. Default to 4.
:param float discount_factor: in [0, 1]. Default to 0.99.
:param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
paper. Default to 0.2.
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
where c > 1 is a constant indicating the lower bound.
Default to 5.0 (set None if you do not want to use it).
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
Default to True.
:param bool advantage_normalization: whether to do per mini-batch advantage
normalization. Default to True.
:param bool recompute_advantage: whether to recompute advantage every update
repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
Default to False.
:param float vf_coef: weight for value loss. Default to 0.5.
:param float ent_coef: weight for entropy loss. Default to 0.01.
:param float max_grad_norm: clipping gradients in back propagation. Default to
None.
:param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close
        to 1 and normalize the advantage to Normal(0, 1). Default to False.
:param int max_batchsize: the maximum size of the batch when computing GAE,
depends on the size of available memory and the memory cost of the model;
should be as large as possible within the memory constraint. Default to 256.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action), "tanh" (for applying tanh
squashing) for now, or empty string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
:param bool deterministic_eval: whether to use deterministic action instead of
stochastic action sampled by the policy. Default to False.
.. seealso::
Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module,
disc_optim: torch.optim.Optimizer,
disc_update_num: int = 4,
eps_clip: float = 0.2,
dual_clip: Optional[float] = None,
value_clip: bool = False,
advantage_normalization: bool = True,
recompute_advantage: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
actor, critic, optim, dist_fn, eps_clip, dual_clip, value_clip,
advantage_normalization, recompute_advantage, **kwargs
)
self.disc_net = disc_net
self.disc_optim = disc_optim
self.disc_update_num = disc_update_num
self.expert_buffer = expert_buffer
self.action_dim = actor.output_dim
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
"""Pre-process the data from the provided replay buffer.
Used in :meth:`update`. Check out :ref:`process_fn` for more information.
"""
# update reward
with torch.no_grad():
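            # GAIL surrogate reward: -log(1 - sigmoid(D(s, a))), which is large when
            # the discriminator believes the transition comes from the expert buffer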
batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten())
return super().process_fn(batch, buffer, indices)
def disc(self, batch: Batch) -> torch.Tensor:
obs = to_torch(batch.obs, device=self.disc_net.device) # type: ignore
act = to_torch(batch.act, device=self.disc_net.device) # type: ignore
return self.disc_net(torch.cat([obs, act], dim=1)) # type: ignore
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
# update discriminator
losses = []
acc_pis = []
acc_exps = []
bsz = len(batch) // self.disc_update_num
for b in batch.split(bsz, merge_last=True):
logits_pi = self.disc(b)
exp_b = self.expert_buffer.sample(bsz)[0]
logits_exp = self.disc(exp_b)
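            # logistic loss on raw logits: push policy samples towards negative
            # logits and expert samples towards positive logits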
loss_pi = -F.logsigmoid(-logits_pi).mean()
loss_exp = -F.logsigmoid(logits_exp).mean()
loss_disc = loss_pi + loss_exp
self.disc_optim.zero_grad()
loss_disc.backward()
self.disc_optim.step()
losses.append(loss_disc.item())
acc_pis.append((logits_pi < 0).float().mean().item())
acc_exps.append((logits_exp > 0).float().mean().item())
# update policy
res = super().learn(batch, batch_size, repeat, **kwargs)
res["loss/disc"] = losses
res["stats/acc_pi"] = acc_pis
res["stats/acc_exp"] = acc_exps
return res
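
For reference, a sketch of the objective that `learn` and `process_fn` implement, writing the raw `disc_net` output as a logit $d_\phi(s, a)$ and $\sigma$ for the sigmoid (the standard GAIL discriminator loss and surrogate reward, up to sign conventions):

$$
\min_\phi \; \mathbb{E}_{(s,a) \sim \pi_\theta}\big[-\log\big(1 - \sigma(d_\phi(s, a))\big)\big]
+ \mathbb{E}_{(s,a) \sim \pi_E}\big[-\log \sigma(d_\phi(s, a))\big],
\qquad
\tilde{r}(s, a) = -\log\big(1 - \sigma(d_\phi(s, a))\big).
$$

The surrogate reward $\tilde{r}$ replaces the environment reward before the parent PPO update.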