Implement Generative Adversarial Imitation Learning (GAIL) (#550)
Implement GAIL on top of PPO, and provide an example script plus sample (i.e., most likely not the best) results on MuJoCo tasks. (#531, #173)
This commit is contained in:
parent d976a5aa91
commit 2377f2f186
@@ -41,6 +41,7 @@
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
@@ -134,6 +134,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.GAILPolicy
    :members:
    :undoc-members:
    :show-inheritance:

Model-based
-----------

@@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning <https://arxiv.org/pdf/1606.03476.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
@@ -140,3 +140,6 @@ Strens
Ornstein
Uhlenbeck
mse
gail
airl
ppo
examples/inverse/README.md (new file, 27 lines)
@@ -0,0 +1,27 @@
# Inverse Reinforcement Learning

In the inverse reinforcement learning setting, the agent learns a policy from interaction with an environment that provides no reward, together with a fixed dataset collected by an expert policy.

## Continuous control

Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.

We provide an implementation of the GAIL algorithm for continuous control.

### Train

You can parse d4rl datasets into a `ReplayBuffer`, and set it as the `expert_buffer` parameter of `GAILPolicy`. `irl_gail.py` is an example of inverse RL using the d4rl dataset.
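For reference, here is a minimal sketch of that parsing step; it mirrors the loading loop in `irl_gail.py`, with `halfcheetah-expert-v2` used purely as an example dataset name:

```python
import d4rl  # importing d4rl registers its environments in gym
import gym

from tianshou.data import Batch, ReplayBuffer

# Convert a d4rl q-learning dataset into a Tianshou ReplayBuffer, which can
# then be passed to GAILPolicy as expert_buffer.
dataset = d4rl.qlearning_dataset(gym.make("halfcheetah-expert-v2"))
expert_buffer = ReplayBuffer(dataset["rewards"].size)
for i in range(dataset["rewards"].size):
    expert_buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=dataset["rewards"][i],
            done=dataset["terminals"][i],
            obs_next=dataset["next_observations"][i],
        )
    )
```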
To train an agent with the GAIL algorithm:

```bash
python irl_gail.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
```

## GAIL (single run)

| task           | best reward | reward curve                             | parameters                                                   |
| -------------- | ----------- | ---------------------------------------- | ------------------------------------------------------------ |
| HalfCheetah-v2 | 5177.07     | ![](results/gail/HalfCheetah-v2_rew.png) | `python3 irl_gail.py --task "HalfCheetah-v2" --expert-data-task "halfcheetah-expert-v2"` |
| Hopper-v2      | 1761.44     | ![](results/gail/Hopper-v2_rew.png)      | `python3 irl_gail.py --task "Hopper-v2" --expert-data-task "hopper-expert-v2"` |
| Walker2d-v2    | 2020.77     | ![](results/gail/Walker2d-v2_rew.png)    | `python3 irl_gail.py --task "Walker2d-v2" --expert-data-task "walker2d-expert-v2"` |
examples/inverse/irl_gail.py (new file, 277 lines)
@@ -0,0 +1,277 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import d4rl
import gym
import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import GAILPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic


class NoRewardEnv(gym.RewardWrapper):
    """Set the reward to 0.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        """Set reward to 0."""
        return np.zeros_like(reward)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetah-v2')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument(
        '--expert-data-task', type=str, default='halfcheetah-expert-v2'
    )
    parser.add_argument('--buffer-size', type=int, default=4096)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--disc-lr', type=float, default=2.5e-5)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=2048)
    parser.add_argument('--repeat-per-collect', type=int, default=10)
    parser.add_argument('--disc-update-num', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--training-num', type=int, default=64)
    parser.add_argument('--test-num', type=int, default=10)
    # ppo special
    parser.add_argument('--rew-norm', type=int, default=True)
    # In theory, `vf-coef` will not make any difference if using Adam optimizer.
    parser.add_argument('--vf-coef', type=float, default=0.25)
    parser.add_argument('--ent-coef', type=float, default=0.001)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--bound-action-method', type=str, default="clip")
    parser.add_argument('--lr-decay', type=int, default=True)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=0)
    parser.add_argument('--norm-adv', type=int, default=0)
    parser.add_argument('--recompute-adv', type=int, default=1)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    return parser.parse_args()


def test_gail(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)],
        norm_obs=True
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False
    )

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device
    )
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        unbounded=True,
        device=args.device
    ).to(args.device)
    net_c = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
    # discriminator
    net_d = Net(
        args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
        concat=True
    )
    disc_net = Critic(net_d, device=args.device).to(args.device)
    for m in disc_net.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    # expert replay buffer
    dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
    dataset_size = dataset['rewards'].size

    print("dataset_size", dataset_size)
    expert_buffer = ReplayBuffer(dataset_size)

    for i in range(dataset_size):
        expert_buffer.add(
            Batch(
                obs=dataset['observations'][i],
                act=dataset['actions'][i],
                rew=dataset['rewards'][i],
                done=dataset['terminals'][i],
                obs_next=dataset['next_observations'][i],
            )
        )
    print("dataset loaded")

    policy = GAILPolicy(
        actor,
        critic,
        optim,
        dist,
        expert_buffer,
        disc_net,
        disc_optim,
        disc_update_num=args.disc_update_num,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail'
    log_path = os.path.join(args.logdir, args.task, 'gail', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.repeat_per_collect,
            args.test_num,
            args.batch_size,
            step_per_collect=args.step_per_collect,
            save_fn=save_fn,
            logger=logger,
            test_in_train=False
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_gail()
examples/inverse/results/gail/HalfCheetah-v2_rew.png (new binary file, 199 KiB, not shown)
examples/inverse/results/gail/Hopper-v2_rew.png (new binary file, 233 KiB, not shown)
examples/inverse/results/gail/Walker2d-v2_rew.png (new binary file, 205 KiB, not shown)
test/offline/test_gail.py (new file, 228 lines)
@@ -0,0 +1,228 @@
import argparse
import os
import pickle
import pprint

import gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import GAILPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic

if __name__ == "__main__":
    from gather_pendulum_data import expert_file_name, gather_data
else:  # pytest
    from test.offline.gather_pendulum_data import expert_file_name, gather_data


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v1')
    parser.add_argument("--reward-threshold", type=float, default=None)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--disc-lr', type=float, default=5e-4)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=150000)
    parser.add_argument('--episode-per-collect', type=int, default=16)
    parser.add_argument('--repeat-per-collect', type=int, default=2)
    parser.add_argument('--disc-update-num', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    # ppo special
    parser.add_argument('--vf-coef', type=float, default=0.25)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=1)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=1)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--recompute-adv', type=int, default=0)
    parser.add_argument('--resume', action="store_true")
    parser.add_argument("--save-interval", type=int, default=4)
    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
    args = parser.parse_known_args()[0]
    return args


def test_gail(args=get_args()):
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            buffer = pickle.load(open(args.load_buffer_name, "rb"))
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    critic = Critic(
        Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
        device=args.device
    ).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    # discriminator
    disc_net = Critic(
        Net(
            args.state_shape,
            action_shape=args.action_shape,
            hidden_sizes=args.hidden_sizes,
            activation=torch.nn.Tanh,
            device=args.device,
            concat=True,
        ),
        device=args.device
    ).to(args.device)
    for m in disc_net.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = GAILPolicy(
        actor,
        critic,
        optim,
        dist,
        buffer,
        disc_net,
        disc_optim,
        disc_update_num=args.disc_update_num,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        gae_lambda=args.gae_lambda,
        action_space=env.action_space,
    )
    # collector
    train_collector = Collector(
        policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
    )
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'gail')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer, save_interval=args.save_interval)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        torch.save(
            {
                'model': policy.state_dict(),
                'optim': optim.state_dict(),
            }, os.path.join(log_path, 'checkpoint.pth')
        )

    if args.resume:
        # load from existing checkpoint
        print(f"Loading agent under {log_path}")
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint['model'])
            optim.load_state_dict(checkpoint['optim'])
            print("Successfully restored policy and optim.")
        else:
            print("Failed to restore policy and optim.")

    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_gail()
@@ -24,6 +24,7 @@ from tianshou.policy.imitation.cql import CQLPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
from tianshou.policy.imitation.gail import GAILPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@@ -52,6 +53,7 @@ __all__ = [
    "DiscreteBCQPolicy",
    "DiscreteCQLPolicy",
    "DiscreteCRRPolicy",
    "GAILPolicy",
    "PSRLPolicy",
    "ICMPolicy",
    "MultiAgentPolicyManager",
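With the import and `__all__` entry added above (in `tianshou/policy/__init__.py`, judging by the hunk context), the new policy is exported at the package level. A quick illustrative sanity check, not part of the diff:

```python
# Illustrative check (not part of the commit): GAILPolicy is importable from
# the top-level policy package and subclasses PPOPolicy.
from tianshou.policy import GAILPolicy, PPOPolicy

assert issubclass(GAILPolicy, PPOPolicy)
```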
tianshou/policy/imitation/gail.py (new file, 139 lines)
@@ -0,0 +1,139 @@
from typing import Any, Dict, List, Optional, Type

import numpy as np
import torch
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.policy import PPOPolicy


class GAILPolicy(PPOPolicy):
    r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param ReplayBuffer expert_buffer: the replay buffer containing expert experience.
    :param torch.nn.Module disc_net: the discriminator network; its input dim equals
        state dim plus action dim and its output dim equals 1.
    :param torch.optim.Optimizer disc_optim: the optimizer for the discriminator
        network.
    :param int disc_update_num: the number of discriminator grad steps per model grad
        step. Default to 4.
    :param float discount_factor: in [0, 1]. Default to 0.99.
    :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original
        paper. Default to 0.2.
    :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
        where c > 1 is a constant indicating the lower bound.
        Default to 5.0 (set None if you do not want to use it).
    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
        Default to True.
    :param bool advantage_normalization: whether to do per mini-batch advantage
        normalization. Default to True.
    :param bool recompute_advantage: whether to recompute advantage every update
        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
        Default to False.
    :param float vf_coef: weight for value loss. Default to 0.5.
    :param float ent_coef: weight for entropy loss. Default to 0.01.
    :param float max_grad_norm: clipping gradients in back propagation. Default to
        None.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
        Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close
        to 1, also normalize the advantage to Normal(0, 1). Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the model;
        should be as large as possible within the memory constraint. Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action), "tanh" (for applying tanh
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).
    :param bool deterministic_eval: whether to use deterministic action instead of
        stochastic action sampled by the policy. Default to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        expert_buffer: ReplayBuffer,
        disc_net: torch.nn.Module,
        disc_optim: torch.optim.Optimizer,
        disc_update_num: int = 4,
        eps_clip: float = 0.2,
        dual_clip: Optional[float] = None,
        value_clip: bool = False,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            actor, critic, optim, dist_fn, eps_clip, dual_clip, value_clip,
            advantage_normalization, recompute_advantage, **kwargs
        )
        self.disc_net = disc_net
        self.disc_optim = disc_optim
        self.disc_update_num = disc_update_num
        self.expert_buffer = expert_buffer
        self.action_dim = actor.output_dim

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        # update reward: overwrite batch.rew with the discriminator-based
        # surrogate reward -log(1 - sigmoid(D(s, a)))
        with torch.no_grad():
            batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten())
        return super().process_fn(batch, buffer, indices)

    def disc(self, batch: Batch) -> torch.Tensor:
        obs = to_torch(batch.obs, device=self.disc_net.device)  # type: ignore
        act = to_torch(batch.act, device=self.disc_net.device)  # type: ignore
        return self.disc_net(torch.cat([obs, act], dim=1))  # type: ignore

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        # update discriminator: push policy logits negative and expert logits positive
        losses = []
        acc_pis = []
        acc_exps = []
        bsz = len(batch) // self.disc_update_num
        for b in batch.split(bsz, merge_last=True):
            logits_pi = self.disc(b)
            exp_b = self.expert_buffer.sample(bsz)[0]
            logits_exp = self.disc(exp_b)
            loss_pi = -F.logsigmoid(-logits_pi).mean()
            loss_exp = -F.logsigmoid(logits_exp).mean()
            loss_disc = loss_pi + loss_exp
            self.disc_optim.zero_grad()
            loss_disc.backward()
            self.disc_optim.step()
            losses.append(loss_disc.item())
            acc_pis.append((logits_pi < 0).float().mean().item())
            acc_exps.append((logits_exp > 0).float().mean().item())
        # update policy
        res = super().learn(batch, batch_size, repeat, **kwargs)
        res["loss/disc"] = losses
        res["stats/acc_pi"] = acc_pis
        res["stats/acc_exp"] = acc_exps
        return res
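A note on the surrogate reward in `process_fn`: `-F.logsigmoid(-logits)` equals `-log(1 - sigmoid(logits))`, so under this implementation's convention that higher discriminator logits mean "more expert-like", the agent is rewarded for transitions the discriminator classifies as expert data. A small illustrative check of that identity, not part of the commit:

```python
import torch
import torch.nn.functional as F

# -logsigmoid(-x) == -log(1 - sigmoid(x)): the numerically stable form of the
# -log(1 - D(s, a)) surrogate reward computed in GAILPolicy.process_fn.
logits = torch.linspace(-5.0, 5.0, steps=11)
stable = -F.logsigmoid(-logits)
naive = -torch.log(1.0 - torch.sigmoid(logits))
assert torch.allclose(stable, naive, atol=1e-4)
```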