Implement BCQPolicy and offline_bcq example (#480)

This PR implements BCQPolicy, which can be used to train an offline agent in environments with a continuous action space. An experimental result on 'halfcheetah-expert-v1', a d4rl environment (for offline reinforcement learning), is provided.
Example usage is in examples/offline/offline_bcq.py; a condensed construction sketch is shown below.
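For reference, a condensed sketch of how the pieces fit together, following the example script (the network sizes and hyperparameters are the script's defaults; the `halfcheetah-expert-v1` dimensions are hard-coded here purely for illustration):

```python
import torch

from tianshou.policy import BCQPolicy
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation

# halfcheetah-expert-v1: 17-dim observation, 6-dim action in [-1, 1]
state_dim, action_dim, max_action = 17, 6, 1.0

# perturbation actor: (s, a) -> small correction added to a
net_a = MLP(input_dim=state_dim + action_dim, output_dim=action_dim,
            hidden_sizes=[400, 300])
actor = Perturbation(net_a, max_action=max_action, phi=0.05)

# twin critics: (s, a) -> Q(s, a)
critic1 = Critic(Net(state_dim, action_dim, hidden_sizes=[400, 300], concat=True))
critic2 = Critic(Net(state_dim, action_dim, hidden_sizes=[400, 300], concat=True))

# VAE that models the dataset's action distribution
latent_dim = 2 * action_dim
vae = VAE(
    MLP(input_dim=state_dim + action_dim, hidden_sizes=[750, 750]),  # encoder
    MLP(input_dim=state_dim + latent_dim, output_dim=action_dim,
        hidden_sizes=[750, 750]),  # decoder
    hidden_dim=750, latent_dim=latent_dim, max_action=max_action,
)

policy = BCQPolicy(
    actor, torch.optim.Adam(actor.parameters(), lr=1e-3),
    critic1, torch.optim.Adam(critic1.parameters(), lr=1e-3),
    critic2, torch.optim.Adam(critic2.parameters(), lr=1e-3),
    vae, torch.optim.Adam(vae.parameters()),
    gamma=0.99, tau=0.005, lmbda=0.75,
)
```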
Bernard Tan 2021-11-22 22:21:02 +08:00 committed by GitHub
parent 94d3b27db9
commit 5c5a3db94e
14 changed files with 1003 additions and 1 deletion


@@ -36,6 +36,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)


@@ -109,6 +109,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.BCQPolicy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
:members:
:undoc-members:


@@ -27,6 +27,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_


@@ -0,0 +1,28 @@
# Offline
In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected beforehand by some behavior policy; the agent does not interact with the environment during training.
Once the dataset is collected, it is not changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use the d4rl datasets.
## Train
Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`. `offline_bcq.py` is an example of offline RL using a d4rl dataset.
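A minimal sketch of this dataset-loading step (assuming `d4rl` is installed; `policy` and `test_collector` are constructed as in `offline_bcq.py`):

```python
import d4rl  # registers the d4rl environments with gym
import gym

from tianshou.data import Batch, ReplayBuffer
from tianshou.trainer import offline_trainer

env = gym.make("halfcheetah-expert-v1")
dataset = d4rl.qlearning_dataset(env)

# copy the fixed dataset, transition by transition, into a ReplayBuffer
buffer = ReplayBuffer(dataset["rewards"].size)
for i in range(dataset["rewards"].size):
    buffer.add(Batch(
        obs=dataset["observations"][i],
        act=dataset["actions"][i],
        rew=dataset["rewards"][i],
        done=dataset["terminals"][i],
        obs_next=dataset["next_observations"][i],
    ))

# `policy` and `test_collector` come from the policy-construction code in
# offline_bcq.py; positional arguments follow that script:
# epoch, step_per_epoch, test_num (episodes per test), batch_size
result = offline_trainer(policy, buffer, test_collector, 200, 5000, 10, 256)
```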
To train an agent with the BCQ algorithm:
```bash
python offline_bcq.py --task halfcheetah-expert-v1
```
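To watch a trained agent (the `--watch` and `--resume-path` flags are defined in `offline_bcq.py`; replace the run directory with your own):
```bash
python offline_bcq.py --task halfcheetah-expert-v1 --watch \
    --resume-path log/halfcheetah-expert-v1/bcq/<run_name>/policy.pth
```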
After 1M steps:
![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)
`halfcheetah-expert-v1` is a MuJoCo environment. The hyperparameter settings are similar to those used for the off-policy algorithms on the MuJoCo tasks.
## Results
| Environment | BCQ |
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |


@@ -0,0 +1,241 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import d4rl
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='halfcheetah-expert-v1')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=1000000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35)
parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750])
# default to 2 * action_dim
parser.add_argument('--latent-dim', type=int)
parser.add_argument("--gamma", default=0.99)
parser.add_argument("--tau", default=0.005)
# Weighting for Clipped Double Q-learning in BCQ
parser.add_argument("--lmbda", default=0.75)
# Max perturbation hyper-parameter for BCQ
parser.add_argument("--phi", default=0.05)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
return parser.parse_args()
def test_bcq():
args = get_args()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
print("device:", args.device)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
print("Max_action", args.max_action)
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
# perturbation network
net_a = MLP(
input_dim=args.state_dim + args.action_dim,
output_dim=args.action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Perturbation(
net_a, max_action=args.max_action, device=args.device, phi=args.phi
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
# vae
# output_dim = 0, so the last Module in the encoder is ReLU
vae_encoder = MLP(
input_dim=args.state_dim + args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
if not args.latent_dim:
args.latent_dim = args.action_dim * 2
vae_decoder = MLP(
input_dim=args.state_dim + args.latent_dim,
output_dim=args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
vae = VAE(
vae_encoder,
vae_decoder,
hidden_dim=args.vae_hidden_sizes[-1],
latent_dim=args.latent_dim,
max_action=args.max_action,
device=args.device,
).to(args.device)
vae_optim = torch.optim.Adam(vae.parameters())
policy = BCQPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
vae,
vae_optim,
device=args.device,
gamma=args.gamma,
tau=args.tau,
lmbda=args.lmbda,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def watch():
if args.resume_path is None:
args.resume_path = os.path.join(log_path, 'policy.pth')
policy.load_state_dict(
torch.load(args.resume_path, map_location=torch.device('cpu'))
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
dataset = d4rl.qlearning_dataset(env)
dataset_size = dataset['rewards'].size
print("dataset_size", dataset_size)
replay_buffer = ReplayBuffer(dataset_size)
for i in range(dataset_size):
replay_buffer.add(
Batch(
obs=dataset['observations'][i],
act=dataset['actions'][i],
rew=dataset['rewards'][i],
done=dataset['terminals'][i],
obs_next=dataset['next_observations'][i],
)
)
print("dataset loaded")
# trainer
result = offline_trainer(
policy,
replay_buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_fn=save_fn,
logger=logger,
)
pprint.pprint(result)
else:
watch()
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
test_bcq()

(Two binary image files added, 55 KiB and 28 KiB; diffs not shown.)

@@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
SubprocVectorEnv(env_fns),
ShmemVectorEnv(env_fns),
]
if has_ray():
if has_ray() and sys.platform == "linux":
venv += [RayVectorEnv(env_fns)]
for v in venv:
v.seed(0)

test/offline/__init__.py (new empty file, 0 lines)


@@ -0,0 +1,170 @@
import argparse
import os
import pickle
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=200000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--epoch', type=int, default=7)
parser.add_argument('--step-per-epoch', type=int, default=8000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.125)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--gamma", default=0.99)
parser.add_argument("--tau", default=0.005)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
)
# sac:
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', type=int, default=1)
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
"--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0]
return args
def gather_data():
"""Return expert buffer data."""
args = get_args()
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
net,
args.action_shape,
max_action=args.max_action,
device=args.device,
unbounded=True,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=args.tau,
gamma=args.gamma,
alpha=args.alpha,
reward_normalization=args.rew_norm,
estimation_step=args.n_step,
action_space=env.action_space,
)
# collector
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
save_fn=save_fn,
stop_fn=stop_fn,
logger=logger,
)
train_collector.reset()
result = train_collector.collect(n_step=args.buffer_size)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
pickle.dump(buffer, open(args.save_buffer_name, "wb"))
return buffer

test/offline/test_bcq.py (new file, 221 lines)

@@ -0,0 +1,221 @@
import argparse
import datetime
import os
import pickle
import pprint
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
if __name__ == "__main__":
from gather_pendulum_data import gather_data
else: # pytest
from test.offline.gather_pendulum_data import gather_data
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--epoch', type=int, default=7)
parser.add_argument('--step-per-epoch', type=int, default=2000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375])
# default to 2 * action_dim
parser.add_argument('--latent_dim', type=int, default=None)
parser.add_argument("--gamma", default=0.99)
parser.add_argument("--tau", default=0.005)
# Weighting for Clipped Double Q-learning in BCQ
parser.add_argument("--lmbda", default=0.75)
# Max perturbation hyper-parameter for BCQ
parser.add_argument("--phi", default=0.05)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
parser.add_argument(
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0]
return args
def test_bcq(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -800 # too low?
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
# perturbation network
net_a = MLP(
input_dim=args.state_dim + args.action_dim,
output_dim=args.action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = Perturbation(
net_a, max_action=args.max_action, device=args.device, phi=args.phi
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
# vae
# output_dim = 0, so the last Module in the encoder is ReLU
vae_encoder = MLP(
input_dim=args.state_dim + args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
if not args.latent_dim:
args.latent_dim = args.action_dim * 2
vae_decoder = MLP(
input_dim=args.state_dim + args.latent_dim,
output_dim=args.action_dim,
hidden_sizes=args.vae_hidden_sizes,
device=args.device,
)
vae = VAE(
vae_encoder,
vae_decoder,
hidden_dim=args.vae_hidden_sizes[-1],
latent_dim=args.latent_dim,
max_action=args.max_action,
device=args.device,
).to(args.device)
vae_optim = torch.optim.Adam(vae.parameters())
policy = BCQPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
vae,
vae_optim,
device=args.device,
gamma=args.gamma,
tau=args.tau,
lmbda=args.lmbda,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
# buffer has been gathered
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def watch():
policy.load_state_dict(
torch.load(
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
)
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
# trainer
result = offline_trainer(
policy,
buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_fn=save_fn,
stop_fn=stop_fn,
logger=logger,
)
assert stop_fn(result['best_reward'])
# Let's watch its performance!
if __name__ == '__main__':
pprint.pprint(result)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
test_bcq()


@@ -19,6 +19,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.bcq import BCQPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
@@ -44,6 +45,7 @@ __all__ = [
"SACPolicy",
"DiscreteSACPolicy",
"ImitationPolicy",
"BCQPolicy",
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"DiscreteCRRPolicy",


@@ -0,0 +1,213 @@
import copy
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
from tianshou.utils.net.continuous import VAE
class BCQPolicy(BasePolicy):
"""Implementation of BCQ algorithm. arXiv:1812.02900.
:param Perturbation actor: the actor perturbation network. (s, a -> perturbed a)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param VAE vae: the VAE network, generating actions similar
to those in batch. (s, a -> generated a)
:param torch.optim.Optimizer vae_optim: the optimizer for the VAE network.
:param Union[str, torch.device] device: which device to create this model on.
Default to "cpu".
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param float tau: param for soft update of the target network.
Default to 0.005.
:param float lmbda: param for Clipped Double Q-learning. Default to 0.75.
:param int forward_sampled_times: the number of actions sampled in the forward
function. The policy samples this many actions and takes the one with the
highest Q value. Default to 100.
:param int num_sampled_action: the number of sampled actions in calculating
target Q. The algorithm samples several actions using VAE, and perturbs
each action to get the target Q. Default to 10.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
vae: VAE,
vae_optim: torch.optim.Optimizer,
device: Union[str, torch.device] = "cpu",
gamma: float = 0.99,
tau: float = 0.005,
lmbda: float = 0.75,
forward_sampled_times: int = 100,
num_sampled_action: int = 10,
**kwargs: Any
) -> None:
# actor is Perturbation!
super().__init__(**kwargs)
self.actor = actor
self.actor_target = copy.deepcopy(self.actor)
self.actor_optim = actor_optim
self.critic1 = critic1
self.critic1_target = copy.deepcopy(self.critic1)
self.critic1_optim = critic1_optim
self.critic2 = critic2
self.critic2_target = copy.deepcopy(self.critic2)
self.critic2_optim = critic2_optim
self.vae = vae
self.vae_optim = vae_optim
self.gamma = gamma
self.tau = tau
self.lmbda = lmbda
self.device = device
self.forward_sampled_times = forward_sampled_times
self.num_sampled_action = num_sampled_action
def train(self, mode: bool = True) -> "BCQPolicy":
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
self.critic2.train(mode)
return self
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data."""
# There is "obs" in the Batch
# obs_group: a batch of observations; each row is one state.
obs_group: torch.Tensor = to_torch( # type: ignore
batch.obs, device=self.device
)
act = []
for obs in obs_group:
# now obs is (state_dim)
obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
# now obs is (forward_sampled_times, state_dim)
# decode(obs) generates action and actor perturbs it
action = self.actor(obs, self.vae.decode(obs))
# now action is (forward_sampled_times, action_dim)
q1 = self.critic1(obs, action)
# q1 is (forward_sampled_times, 1)
ind = q1.argmax(0)
act.append(action[ind].cpu().data.numpy().flatten())
act = np.array(act)
return Batch(act=act)
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for net, net_target in [
[self.critic1, self.critic1_target], [self.critic2, self.critic2_target],
[self.actor, self.actor_target]
]:
for param, target_param in zip(net.parameters(), net_target.parameters()):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# batch: obs, act, rew, done, obs_next. (numpy array)
# (batch_size, state_dim)
batch: Batch = to_torch( # type: ignore
batch, dtype=torch.float, device=self.device
)
obs, act = batch.obs, batch.act
batch_size = obs.shape[0]
# mean, std: (state.shape[0], latent_dim)
recon, mean, std = self.vae(obs, act)
recon_loss = F.mse_loss(act, recon)
# (....) is D_KL( N(mu, sigma) || N(0,1) )
KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean()
vae_loss = recon_loss + KL_loss / 2
self.vae_optim.zero_grad()
vae_loss.backward()
self.vae_optim.step()
# critic training:
with torch.no_grad():
# repeat num_sampled_action times
obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0)
# now obs_next: (num_sampled_action * batch_size, state_dim)
# candidate actions for the next states, generated by the VAE
act_next = self.vae.decode(obs_next)
# now act_next: (num_sampled_action * batch_size, action_dim)
target_Q1 = self.critic1_target(obs_next, act_next)
target_Q2 = self.critic2_target(obs_next, act_next)
# Clipped Double Q-learning
target_Q = \
self.lmbda * torch.min(target_Q1, target_Q2) + \
(1 - self.lmbda) * torch.max(target_Q1, target_Q2)
# now target_Q: (num_sampled_action * batch_size, 1)
# the max value of Q
target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1)
# now target_Q: (batch_size, 1)
target_Q = \
batch.rew.reshape(-1, 1) + \
(1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
current_Q1 = self.critic1(obs, act)
current_Q2 = self.critic2(obs, act)
critic1_loss = F.mse_loss(current_Q1, target_Q)
critic2_loss = F.mse_loss(current_Q2, target_Q)
self.critic1_optim.zero_grad()
self.critic2_optim.zero_grad()
critic1_loss.backward()
critic2_loss.backward()
self.critic1_optim.step()
self.critic2_optim.step()
sampled_act = self.vae.decode(obs)
perturbed_act = self.actor(obs, sampled_act)
# max
actor_loss = -self.critic1(obs, perturbed_act).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
# update target network
self.sync_weight()
result = {
"loss/actor": actor_loss.item(),
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
"loss/vae": vae_loss.item(),
}
return result


@@ -325,3 +325,122 @@ class RecurrentCritic(nn.Module):
s = torch.cat([s, a], dim=1)
s = self.fc2(s)
return s
class Perturbation(nn.Module):
"""Implementation of perturbation network in BCQ algorithm. Given a state and \
action, it can generate a perturbed action.
:param torch.nn.Module preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param float max_action: the maximum value of each dimension of action.
:param Union[str, int, torch.device] device: which device to create this model on.
Default to cpu.
:param float phi: max perturbation parameter for BCQ. Default to 0.05.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
"""
def __init__(
self,
preprocess_net: nn.Module,
max_action: float,
device: Union[str, int, torch.device] = "cpu",
phi: float = 0.05
):
# preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
super(Perturbation, self).__init__()
self.preprocess_net = preprocess_net
self.device = device
self.max_action = max_action
self.phi = phi
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
# preprocess_net
logits = self.preprocess_net(torch.cat([state, action], -1))[0]
a = self.phi * self.max_action * torch.tanh(logits)
# clip to [-max_action, max_action]
return (a + action).clamp(-self.max_action, self.max_action)
class VAE(nn.Module):
"""Implementation of VAE. It models the distribution of action. Given a \
state, it can generate actions similar to those in batch. It is used \
in BCQ algorithm.
:param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
state_dim + action_dim, and output_dim must be hidden_dim.
:param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
state_dim + latent_dim, and output_dim must be action_dim.
:param int hidden_dim: the size of the last linear-layer in encoder.
:param int latent_dim: the size of latent layer.
:param float max_action: the maximum value of each dimension of action.
:param Union[str, torch.device] device: which device to create this model on.
Default to "cpu".
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
"""
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
hidden_dim: int,
latent_dim: int,
max_action: float,
device: Union[str, torch.device] = "cpu"
):
super(VAE, self).__init__()
self.encoder = encoder
self.mean = nn.Linear(hidden_dim, latent_dim)
self.log_std = nn.Linear(hidden_dim, latent_dim)
self.decoder = decoder
self.max_action = max_action
self.latent_dim = latent_dim
self.device = device
def forward(
self, state: torch.Tensor, action: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [state, action] -> z , [state, z] -> action
z = self.encoder(torch.cat([state, action], -1))
# shape of z: (state.shape[:-1], hidden_dim)
mean = self.mean(z)
# Clamped for numerical stability
log_std = self.log_std(z).clamp(-4, 15)
std = torch.exp(log_std)
# shape of mean, std: (state.shape[:-1], latent_dim)
z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
u = self.decode(state, z) # (state.shape[:-1], action_dim)
return u, mean, std
def decode(
self,
state: torch.Tensor,
z: Union[torch.Tensor, None] = None
) -> torch.Tensor:
# decode(state) -> action
if z is None:
# state.shape[0] may be batch_size
# latent vector clipped to [-0.5, 0.5]
z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
.to(self.device).clamp(-0.5, 0.5)
# decode z with state!
return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))