Implement BCQPolicy and offline_bcq example (#480)
This PR implements BCQPolicy, which could be used to train an offline agent in the environment of continuous action space. An experimental result 'halfcheetah-expert-v1' is provided, which is a d4rl environment (for Offline Reinforcement Learning). Example usage is in the examples/offline/offline_bcq.py.
This commit is contained in:
parent
94d3b27db9
commit
5c5a3db94e
@ -36,6 +36,7 @@
|
|||||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||||
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
||||||
- Vanilla Imitation Learning
|
- Vanilla Imitation Learning
|
||||||
|
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
|
||||||
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
||||||
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||||
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
||||||
|
|||||||
@ -109,6 +109,11 @@ Imitation
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autoclass:: tianshou.policy.BCQPolicy
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
|
|||||||
@ -27,6 +27,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
||||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||||
|
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||||
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
||||||
|
|||||||
28
examples/offline/README.md
Normal file
28
examples/offline/README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Offline
|
||||||
|
|
||||||
|
In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.
|
||||||
|
|
||||||
|
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
|
||||||
|
|
||||||
|
## Train
|
||||||
|
|
||||||
|
Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
|
||||||
|
|
||||||
|
To train an agent with BCQ algorithm:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python offline_bcq.py --task halfcheetah-expert-v1
|
||||||
|
```
|
||||||
|
|
||||||
|
After 1M steps:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
| Environment | BCQ |
|
||||||
|
| --------------------- | --------------- |
|
||||||
|
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |
|
||||||
|
|
||||||
241
examples/offline/offline_bcq.py
Normal file
241
examples/offline/offline_bcq.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import d4rl
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
|
from tianshou.env import SubprocVectorEnv
|
||||||
|
from tianshou.policy import BCQPolicy
|
||||||
|
from tianshou.trainer import offline_trainer
|
||||||
|
from tianshou.utils import BasicLogger
|
||||||
|
from tianshou.utils.net.common import MLP, Net
|
||||||
|
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='halfcheetah-expert-v1')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300])
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||||
|
parser.add_argument('--epoch', type=int, default=200)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=256)
|
||||||
|
parser.add_argument('--training-num', type=int, default=10)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=1 / 35)
|
||||||
|
|
||||||
|
parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750])
|
||||||
|
# default to 2 * action_dim
|
||||||
|
parser.add_argument('--latent-dim', type=int)
|
||||||
|
parser.add_argument("--gamma", default=0.99)
|
||||||
|
parser.add_argument("--tau", default=0.005)
|
||||||
|
# Weighting for Clipped Double Q-learning in BCQ
|
||||||
|
parser.add_argument("--lmbda", default=0.75)
|
||||||
|
# Max perturbation hyper-parameter for BCQ
|
||||||
|
parser.add_argument("--phi", default=0.05)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
'--watch',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only',
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bcq():
|
||||||
|
args = get_args()
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0] # float
|
||||||
|
print("device:", args.device)
|
||||||
|
print("Observations shape:", args.state_shape)
|
||||||
|
print("Actions shape:", args.action_shape)
|
||||||
|
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
|
||||||
|
|
||||||
|
args.state_dim = args.state_shape[0]
|
||||||
|
args.action_dim = args.action_shape[0]
|
||||||
|
print("Max_action", args.max_action)
|
||||||
|
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
train_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||||
|
)
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
|
||||||
|
# model
|
||||||
|
# perturbation network
|
||||||
|
net_a = MLP(
|
||||||
|
input_dim=args.state_dim + args.action_dim,
|
||||||
|
output_dim=args.action_dim,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
actor = Perturbation(
|
||||||
|
net_a, max_action=args.max_action, device=args.device, phi=args.phi
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
net_c1 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
net_c2 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||||
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
|
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||||
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
# vae
|
||||||
|
# output_dim = 0, so the last Module in the encoder is ReLU
|
||||||
|
vae_encoder = MLP(
|
||||||
|
input_dim=args.state_dim + args.action_dim,
|
||||||
|
hidden_sizes=args.vae_hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
if not args.latent_dim:
|
||||||
|
args.latent_dim = args.action_dim * 2
|
||||||
|
vae_decoder = MLP(
|
||||||
|
input_dim=args.state_dim + args.latent_dim,
|
||||||
|
output_dim=args.action_dim,
|
||||||
|
hidden_sizes=args.vae_hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
vae = VAE(
|
||||||
|
vae_encoder,
|
||||||
|
vae_decoder,
|
||||||
|
hidden_dim=args.vae_hidden_sizes[-1],
|
||||||
|
latent_dim=args.latent_dim,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
).to(args.device)
|
||||||
|
vae_optim = torch.optim.Adam(vae.parameters())
|
||||||
|
|
||||||
|
policy = BCQPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic1,
|
||||||
|
critic1_optim,
|
||||||
|
critic2,
|
||||||
|
critic2_optim,
|
||||||
|
vae,
|
||||||
|
vae_optim,
|
||||||
|
device=args.device,
|
||||||
|
gamma=args.gamma,
|
||||||
|
tau=args.tau,
|
||||||
|
lmbda=args.lmbda,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
|
||||||
|
# collector
|
||||||
|
if args.training_num > 1:
|
||||||
|
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||||
|
else:
|
||||||
|
buffer = ReplayBuffer(args.buffer_size)
|
||||||
|
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||||
|
# log
|
||||||
|
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||||
|
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = BasicLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def watch():
|
||||||
|
if args.resume_path is None:
|
||||||
|
args.resume_path = os.path.join(log_path, 'policy.pth')
|
||||||
|
|
||||||
|
policy.load_state_dict(
|
||||||
|
torch.load(args.resume_path, map_location=torch.device('cpu'))
|
||||||
|
)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
collector.collect(n_episode=1, render=1 / 35)
|
||||||
|
|
||||||
|
if not args.watch:
|
||||||
|
dataset = d4rl.qlearning_dataset(env)
|
||||||
|
dataset_size = dataset['rewards'].size
|
||||||
|
|
||||||
|
print("dataset_size", dataset_size)
|
||||||
|
replay_buffer = ReplayBuffer(dataset_size)
|
||||||
|
|
||||||
|
for i in range(dataset_size):
|
||||||
|
replay_buffer.add(
|
||||||
|
Batch(
|
||||||
|
obs=dataset['observations'][i],
|
||||||
|
act=dataset['actions'][i],
|
||||||
|
rew=dataset['rewards'][i],
|
||||||
|
done=dataset['terminals'][i],
|
||||||
|
obs_next=dataset['next_observations'][i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("dataset loaded")
|
||||||
|
# trainer
|
||||||
|
result = offline_trainer(
|
||||||
|
policy,
|
||||||
|
replay_buffer,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
save_fn=save_fn,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
pprint.pprint(result)
|
||||||
|
else:
|
||||||
|
watch()
|
||||||
|
|
||||||
|
# Let's watch its performance!
|
||||||
|
policy.eval()
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
test_collector.reset()
|
||||||
|
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||||
|
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_bcq()
|
||||||
BIN
examples/offline/results/bcq/halfcheetah-expert-v1_reward.png
Normal file
BIN
examples/offline/results/bcq/halfcheetah-expert-v1_reward.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 28 KiB |
@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
|||||||
SubprocVectorEnv(env_fns),
|
SubprocVectorEnv(env_fns),
|
||||||
ShmemVectorEnv(env_fns),
|
ShmemVectorEnv(env_fns),
|
||||||
]
|
]
|
||||||
if has_ray():
|
if has_ray() and sys.platform == "linux":
|
||||||
venv += [RayVectorEnv(env_fns)]
|
venv += [RayVectorEnv(env_fns)]
|
||||||
for v in venv:
|
for v in venv:
|
||||||
v.seed(0)
|
v.seed(0)
|
||||||
|
|||||||
0
test/offline/__init__.py
Normal file
0
test/offline/__init__.py
Normal file
170
test/offline/gather_pendulum_data.py
Normal file
170
test/offline/gather_pendulum_data.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Collector, VectorReplayBuffer
|
||||||
|
from tianshou.env import DummyVectorEnv
|
||||||
|
from tianshou.policy import SACPolicy
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.utils import TensorboardLogger
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=200000)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--epoch', type=int, default=7)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=8000)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=256)
|
||||||
|
parser.add_argument('--training-num', type=int, default=10)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||||
|
parser.add_argument('--update-per-step', type=float, default=0.125)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
|
||||||
|
parser.add_argument("--gamma", default=0.99)
|
||||||
|
parser.add_argument("--tau", default=0.005)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
'--watch',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only'
|
||||||
|
)
|
||||||
|
# sac:
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.2)
|
||||||
|
parser.add_argument('--auto-alpha', type=int, default=1)
|
||||||
|
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||||
|
parser.add_argument('--rew-norm', action="store_true", default=False)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
|
||||||
|
)
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def gather_data():
|
||||||
|
"""Return expert buffer data."""
|
||||||
|
args = get_args()
|
||||||
|
env = gym.make(args.task)
|
||||||
|
if args.task == 'Pendulum-v0':
|
||||||
|
env.spec.reward_threshold = -250
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0]
|
||||||
|
# you can also use tianshou.env.SubprocVectorEnv
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
train_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)]
|
||||||
|
)
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = DummyVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||||
|
actor = ActorProb(
|
||||||
|
net,
|
||||||
|
args.action_shape,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
unbounded=True,
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
net_c1 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||||
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
|
net_c2 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||||
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
if args.auto_alpha:
|
||||||
|
target_entropy = -np.prod(env.action_space.shape)
|
||||||
|
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
|
||||||
|
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
|
||||||
|
args.alpha = (target_entropy, log_alpha, alpha_optim)
|
||||||
|
|
||||||
|
policy = SACPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic1,
|
||||||
|
critic1_optim,
|
||||||
|
critic2,
|
||||||
|
critic2_optim,
|
||||||
|
tau=args.tau,
|
||||||
|
gamma=args.gamma,
|
||||||
|
alpha=args.alpha,
|
||||||
|
reward_normalization=args.rew_norm,
|
||||||
|
estimation_step=args.n_step,
|
||||||
|
action_space=env.action_space,
|
||||||
|
)
|
||||||
|
# collector
|
||||||
|
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||||
|
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
# train_collector.collect(n_step=args.buffer_size)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'sac')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
offpolicy_trainer(
|
||||||
|
policy,
|
||||||
|
train_collector,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.step_per_collect,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
update_per_step=args.update_per_step,
|
||||||
|
save_fn=save_fn,
|
||||||
|
stop_fn=stop_fn,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
train_collector.reset()
|
||||||
|
result = train_collector.collect(n_step=args.buffer_size)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
pickle.dump(buffer, open(args.save_buffer_name, "wb"))
|
||||||
|
return buffer
|
||||||
221
test/offline/test_bcq.py
Normal file
221
test/offline/test_bcq.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.data import Collector
|
||||||
|
from tianshou.env import SubprocVectorEnv
|
||||||
|
from tianshou.policy import BCQPolicy
|
||||||
|
from tianshou.trainer import offline_trainer
|
||||||
|
from tianshou.utils import TensorboardLogger
|
||||||
|
from tianshou.utils.net.common import MLP, Net
|
||||||
|
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from gather_pendulum_data import gather_data
|
||||||
|
else: # pytest
|
||||||
|
from test.offline.gather_pendulum_data import gather_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||||
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150])
|
||||||
|
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--epoch', type=int, default=7)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=2000)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=256)
|
||||||
|
parser.add_argument('--test-num', type=int, default=10)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
|
||||||
|
parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375])
|
||||||
|
# default to 2 * action_dim
|
||||||
|
parser.add_argument('--latent_dim', type=int, default=None)
|
||||||
|
parser.add_argument("--gamma", default=0.99)
|
||||||
|
parser.add_argument("--tau", default=0.005)
|
||||||
|
# Weighting for Clipped Double Q-learning in BCQ
|
||||||
|
parser.add_argument("--lmbda", default=0.75)
|
||||||
|
# Max perturbation hyper-parameter for BCQ
|
||||||
|
parser.add_argument("--phi", default=0.05)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
)
|
||||||
|
parser.add_argument('--resume-path', type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
'--watch',
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help='watch the play of pre-trained policy only',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
|
||||||
|
)
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_bcq(args=get_args()):
|
||||||
|
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
|
||||||
|
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||||
|
else:
|
||||||
|
buffer = gather_data()
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
args.max_action = env.action_space.high[0] # float
|
||||||
|
if args.task == 'Pendulum-v0':
|
||||||
|
env.spec.reward_threshold = -800 # too low?
|
||||||
|
|
||||||
|
args.state_dim = args.state_shape[0]
|
||||||
|
args.action_dim = args.action_shape[0]
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = SubprocVectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)]
|
||||||
|
)
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
|
||||||
|
# model
|
||||||
|
# perturbation network
|
||||||
|
net_a = MLP(
|
||||||
|
input_dim=args.state_dim + args.action_dim,
|
||||||
|
output_dim=args.action_dim,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
actor = Perturbation(
|
||||||
|
net_a, max_action=args.max_action, device=args.device, phi=args.phi
|
||||||
|
).to(args.device)
|
||||||
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
|
|
||||||
|
net_c1 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
net_c2 = Net(
|
||||||
|
args.state_shape,
|
||||||
|
args.action_shape,
|
||||||
|
hidden_sizes=args.hidden_sizes,
|
||||||
|
concat=True,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||||
|
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||||
|
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||||
|
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||||
|
|
||||||
|
# vae
|
||||||
|
# output_dim = 0, so the last Module in the encoder is ReLU
|
||||||
|
vae_encoder = MLP(
|
||||||
|
input_dim=args.state_dim + args.action_dim,
|
||||||
|
hidden_sizes=args.vae_hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
if not args.latent_dim:
|
||||||
|
args.latent_dim = args.action_dim * 2
|
||||||
|
vae_decoder = MLP(
|
||||||
|
input_dim=args.state_dim + args.latent_dim,
|
||||||
|
output_dim=args.action_dim,
|
||||||
|
hidden_sizes=args.vae_hidden_sizes,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
vae = VAE(
|
||||||
|
vae_encoder,
|
||||||
|
vae_decoder,
|
||||||
|
hidden_dim=args.vae_hidden_sizes[-1],
|
||||||
|
latent_dim=args.latent_dim,
|
||||||
|
max_action=args.max_action,
|
||||||
|
device=args.device,
|
||||||
|
).to(args.device)
|
||||||
|
vae_optim = torch.optim.Adam(vae.parameters())
|
||||||
|
|
||||||
|
policy = BCQPolicy(
|
||||||
|
actor,
|
||||||
|
actor_optim,
|
||||||
|
critic1,
|
||||||
|
critic1_optim,
|
||||||
|
critic2,
|
||||||
|
critic2_optim,
|
||||||
|
vae,
|
||||||
|
vae_optim,
|
||||||
|
device=args.device,
|
||||||
|
gamma=args.gamma,
|
||||||
|
tau=args.tau,
|
||||||
|
lmbda=args.lmbda,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load a previous policy
|
||||||
|
if args.resume_path:
|
||||||
|
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||||
|
print("Loaded agent from: ", args.resume_path)
|
||||||
|
|
||||||
|
# collector
|
||||||
|
# buffer has been gathered
|
||||||
|
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
# log
|
||||||
|
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||||
|
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
writer.add_text("args", str(args))
|
||||||
|
logger = TensorboardLogger(writer)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(mean_rewards):
|
||||||
|
return mean_rewards >= env.spec.reward_threshold
|
||||||
|
|
||||||
|
def watch():
|
||||||
|
policy.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
collector.collect(n_episode=1, render=1 / 35)
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
result = offline_trainer(
|
||||||
|
policy,
|
||||||
|
buffer,
|
||||||
|
test_collector,
|
||||||
|
args.epoch,
|
||||||
|
args.step_per_epoch,
|
||||||
|
args.test_num,
|
||||||
|
args.batch_size,
|
||||||
|
save_fn=save_fn,
|
||||||
|
stop_fn=stop_fn,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
|
||||||
|
# Let's watch its performance!
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
env = gym.make(args.task)
|
||||||
|
policy.eval()
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
rews, lens = result["rews"], result["lens"]
|
||||||
|
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_bcq()
|
||||||
@ -19,6 +19,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy
|
|||||||
from tianshou.policy.modelfree.sac import SACPolicy
|
from tianshou.policy.modelfree.sac import SACPolicy
|
||||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||||
from tianshou.policy.imitation.base import ImitationPolicy
|
from tianshou.policy.imitation.base import ImitationPolicy
|
||||||
|
from tianshou.policy.imitation.bcq import BCQPolicy
|
||||||
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
||||||
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
||||||
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
||||||
@ -44,6 +45,7 @@ __all__ = [
|
|||||||
"SACPolicy",
|
"SACPolicy",
|
||||||
"DiscreteSACPolicy",
|
"DiscreteSACPolicy",
|
||||||
"ImitationPolicy",
|
"ImitationPolicy",
|
||||||
|
"BCQPolicy",
|
||||||
"DiscreteBCQPolicy",
|
"DiscreteBCQPolicy",
|
||||||
"DiscreteCQLPolicy",
|
"DiscreteCQLPolicy",
|
||||||
"DiscreteCRRPolicy",
|
"DiscreteCRRPolicy",
|
||||||
|
|||||||
213
tianshou/policy/imitation/bcq.py
Normal file
213
tianshou/policy/imitation/bcq.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
import copy
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tianshou.data import Batch, to_torch
|
||||||
|
from tianshou.policy import BasePolicy
|
||||||
|
from tianshou.utils.net.continuous import VAE
|
||||||
|
|
||||||
|
|
||||||
|
class BCQPolicy(BasePolicy):
|
||||||
|
"""Implementation of BCQ algorithm. arXiv:1812.02900.
|
||||||
|
|
||||||
|
:param Perturbation actor: the actor perturbation. (s, a -> perturbed a)
|
||||||
|
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
||||||
|
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
|
||||||
|
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
||||||
|
critic network.
|
||||||
|
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
|
||||||
|
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||||
|
critic network.
|
||||||
|
:param VAE vae: the VAE network, generating actions similar
|
||||||
|
to those in batch. (s, a -> generated a)
|
||||||
|
:param torch.optim.Optimizer vae_optim: the optimizer for the VAE network.
|
||||||
|
:param Union[str, torch.device] device: which device to create this model on.
|
||||||
|
Default to "cpu".
|
||||||
|
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
||||||
|
:param float tau: param for soft update of the target network.
|
||||||
|
Default to 0.005.
|
||||||
|
:param float lmbda: param for Clipped Double Q-learning. Default to 0.75.
|
||||||
|
:param int forward_sampled_times: the number of sampled actions in forward
|
||||||
|
function. The policy samples many actions and takes the action with the
|
||||||
|
max value. Default to 100.
|
||||||
|
:param int num_sampled_action: the number of sampled actions in calculating
|
||||||
|
target Q. The algorithm samples several actions using VAE, and perturbs
|
||||||
|
each action to get the target Q. Default to 10.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
||||||
|
explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor: torch.nn.Module,
|
||||||
|
actor_optim: torch.optim.Optimizer,
|
||||||
|
critic1: torch.nn.Module,
|
||||||
|
critic1_optim: torch.optim.Optimizer,
|
||||||
|
critic2: torch.nn.Module,
|
||||||
|
critic2_optim: torch.optim.Optimizer,
|
||||||
|
vae: VAE,
|
||||||
|
vae_optim: torch.optim.Optimizer,
|
||||||
|
device: Union[str, torch.device] = "cpu",
|
||||||
|
gamma: float = 0.99,
|
||||||
|
tau: float = 0.005,
|
||||||
|
lmbda: float = 0.75,
|
||||||
|
forward_sampled_times: int = 100,
|
||||||
|
num_sampled_action: int = 10,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
# actor is Perturbation!
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.actor = actor
|
||||||
|
self.actor_target = copy.deepcopy(self.actor)
|
||||||
|
self.actor_optim = actor_optim
|
||||||
|
|
||||||
|
self.critic1 = critic1
|
||||||
|
self.critic1_target = copy.deepcopy(self.critic1)
|
||||||
|
self.critic1_optim = critic1_optim
|
||||||
|
|
||||||
|
self.critic2 = critic2
|
||||||
|
self.critic2_target = copy.deepcopy(self.critic2)
|
||||||
|
self.critic2_optim = critic2_optim
|
||||||
|
|
||||||
|
self.vae = vae
|
||||||
|
self.vae_optim = vae_optim
|
||||||
|
|
||||||
|
self.gamma = gamma
|
||||||
|
self.tau = tau
|
||||||
|
self.lmbda = lmbda
|
||||||
|
self.device = device
|
||||||
|
self.forward_sampled_times = forward_sampled_times
|
||||||
|
self.num_sampled_action = num_sampled_action
|
||||||
|
|
||||||
|
def train(self, mode: bool = True) -> "BCQPolicy":
|
||||||
|
"""Set the module in training mode, except for the target network."""
|
||||||
|
self.training = mode
|
||||||
|
self.actor.train(mode)
|
||||||
|
self.critic1.train(mode)
|
||||||
|
self.critic2.train(mode)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: Batch,
|
||||||
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Batch:
|
||||||
|
"""Compute action over the given batch data."""
|
||||||
|
# There is "obs" in the Batch
|
||||||
|
# obs_group: several groups. Each group has a state.
|
||||||
|
obs_group: torch.Tensor = to_torch( # type: ignore
|
||||||
|
batch.obs, device=self.device
|
||||||
|
)
|
||||||
|
act = []
|
||||||
|
for obs in obs_group:
|
||||||
|
# now obs is (state_dim)
|
||||||
|
obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
|
||||||
|
# now obs is (forward_sampled_times, state_dim)
|
||||||
|
|
||||||
|
# decode(obs) generates action and actor perturbs it
|
||||||
|
action = self.actor(obs, self.vae.decode(obs))
|
||||||
|
# now action is (forward_sampled_times, action_dim)
|
||||||
|
q1 = self.critic1(obs, action)
|
||||||
|
# q1 is (forward_sampled_times, 1)
|
||||||
|
ind = q1.argmax(0)
|
||||||
|
act.append(action[ind].cpu().data.numpy().flatten())
|
||||||
|
act = np.array(act)
|
||||||
|
return Batch(act=act)
|
||||||
|
|
||||||
|
def sync_weight(self) -> None:
|
||||||
|
"""Soft-update the weight for the target network."""
|
||||||
|
for net, net_target in [
|
||||||
|
[self.critic1, self.critic1_target], [self.critic2, self.critic2_target],
|
||||||
|
[self.actor, self.actor_target]
|
||||||
|
]:
|
||||||
|
for param, target_param in zip(net.parameters(), net_target.parameters()):
|
||||||
|
target_param.data.copy_(
|
||||||
|
self.tau * param.data + (1 - self.tau) * target_param.data
|
||||||
|
)
|
||||||
|
|
||||||
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||||
|
# batch: obs, act, rew, done, obs_next. (numpy array)
|
||||||
|
# (batch_size, state_dim)
|
||||||
|
batch: Batch = to_torch( # type: ignore
|
||||||
|
batch, dtype=torch.float, device=self.device
|
||||||
|
)
|
||||||
|
obs, act = batch.obs, batch.act
|
||||||
|
batch_size = obs.shape[0]
|
||||||
|
|
||||||
|
# mean, std: (state.shape[0], latent_dim)
|
||||||
|
recon, mean, std = self.vae(obs, act)
|
||||||
|
recon_loss = F.mse_loss(act, recon)
|
||||||
|
# (....) is D_KL( N(mu, sigma) || N(0,1) )
|
||||||
|
KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean()
|
||||||
|
vae_loss = recon_loss + KL_loss / 2
|
||||||
|
|
||||||
|
self.vae_optim.zero_grad()
|
||||||
|
vae_loss.backward()
|
||||||
|
self.vae_optim.step()
|
||||||
|
|
||||||
|
# critic training:
|
||||||
|
with torch.no_grad():
|
||||||
|
# repeat num_sampled_action times
|
||||||
|
obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0)
|
||||||
|
# now obs_next: (num_sampled_action * batch_size, state_dim)
|
||||||
|
|
||||||
|
# perturbed action generated by VAE
|
||||||
|
act_next = self.vae.decode(obs_next)
|
||||||
|
# now obs_next: (num_sampled_action * batch_size, action_dim)
|
||||||
|
target_Q1 = self.critic1_target(obs_next, act_next)
|
||||||
|
target_Q2 = self.critic2_target(obs_next, act_next)
|
||||||
|
|
||||||
|
# Clipped Double Q-learning
|
||||||
|
target_Q = \
|
||||||
|
self.lmbda * torch.min(target_Q1, target_Q2) + \
|
||||||
|
(1 - self.lmbda) * torch.max(target_Q1, target_Q2)
|
||||||
|
# now target_Q: (num_sampled_action * batch_size, 1)
|
||||||
|
|
||||||
|
# the max value of Q
|
||||||
|
target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1)
|
||||||
|
# now target_Q: (batch_size, 1)
|
||||||
|
|
||||||
|
target_Q = \
|
||||||
|
batch.rew.reshape(-1, 1) + \
|
||||||
|
(1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
|
||||||
|
|
||||||
|
current_Q1 = self.critic1(obs, act)
|
||||||
|
current_Q2 = self.critic2(obs, act)
|
||||||
|
|
||||||
|
critic1_loss = F.mse_loss(current_Q1, target_Q)
|
||||||
|
critic2_loss = F.mse_loss(current_Q2, target_Q)
|
||||||
|
|
||||||
|
self.critic1_optim.zero_grad()
|
||||||
|
self.critic2_optim.zero_grad()
|
||||||
|
critic1_loss.backward()
|
||||||
|
critic2_loss.backward()
|
||||||
|
self.critic1_optim.step()
|
||||||
|
self.critic2_optim.step()
|
||||||
|
|
||||||
|
sampled_act = self.vae.decode(obs)
|
||||||
|
perturbed_act = self.actor(obs, sampled_act)
|
||||||
|
|
||||||
|
# max
|
||||||
|
actor_loss = -self.critic1(obs, perturbed_act).mean()
|
||||||
|
|
||||||
|
self.actor_optim.zero_grad()
|
||||||
|
actor_loss.backward()
|
||||||
|
self.actor_optim.step()
|
||||||
|
|
||||||
|
# update target network
|
||||||
|
self.sync_weight()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"loss/actor": actor_loss.item(),
|
||||||
|
"loss/critic1": critic1_loss.item(),
|
||||||
|
"loss/critic2": critic2_loss.item(),
|
||||||
|
"loss/vae": vae_loss.item(),
|
||||||
|
}
|
||||||
|
return result
|
||||||
@ -325,3 +325,122 @@ class RecurrentCritic(nn.Module):
|
|||||||
s = torch.cat([s, a], dim=1)
|
s = torch.cat([s, a], dim=1)
|
||||||
s = self.fc2(s)
|
s = self.fc2(s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class Perturbation(nn.Module):
|
||||||
|
"""Implementation of perturbation network in BCQ algorithm. Given a state and \
|
||||||
|
action, it can generate perturbed action.
|
||||||
|
|
||||||
|
:param torch.nn.Module preprocess_net: a self-defined preprocess_net which output a
|
||||||
|
flattened hidden state.
|
||||||
|
:param float max_action: the maximum value of each dimension of action.
|
||||||
|
:param Union[str, int, torch.device] device: which device to create this model on.
|
||||||
|
Default to cpu.
|
||||||
|
:param float phi: max perturbation parameter for BCQ. Default to 0.05.
|
||||||
|
|
||||||
|
For advanced usage (how to customize the network), please refer to
|
||||||
|
:ref:`build_the_network`.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
preprocess_net: nn.Module,
|
||||||
|
max_action: float,
|
||||||
|
device: Union[str, int, torch.device] = "cpu",
|
||||||
|
phi: float = 0.05
|
||||||
|
):
|
||||||
|
# preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim
|
||||||
|
super(Perturbation, self).__init__()
|
||||||
|
self.preprocess_net = preprocess_net
|
||||||
|
self.device = device
|
||||||
|
self.max_action = max_action
|
||||||
|
self.phi = phi
|
||||||
|
|
||||||
|
def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
# preprocess_net
|
||||||
|
logits = self.preprocess_net(torch.cat([state, action], -1))[0]
|
||||||
|
a = self.phi * self.max_action * torch.tanh(logits)
|
||||||
|
# clip to [-max_action, max_action]
|
||||||
|
return (a + action).clamp(-self.max_action, self.max_action)
|
||||||
|
|
||||||
|
|
||||||
|
class VAE(nn.Module):
|
||||||
|
"""Implementation of VAE. It models the distribution of action. Given a \
|
||||||
|
state, it can generate actions similar to those in batch. It is used \
|
||||||
|
in BCQ algorithm.
|
||||||
|
|
||||||
|
:param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be
|
||||||
|
state_dim + action_dim, and output_dim must be hidden_dim.
|
||||||
|
:param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be
|
||||||
|
state_dim + latent_dim, and output_dim must be action_dim.
|
||||||
|
:param int hidden_dim: the size of the last linear-layer in encoder.
|
||||||
|
:param int latent_dim: the size of latent layer.
|
||||||
|
:param float max_action: the maximum value of each dimension of action.
|
||||||
|
:param Union[str, torch.device] device: which device to create this model on.
|
||||||
|
Default to "cpu".
|
||||||
|
|
||||||
|
For advanced usage (how to customize the network), please refer to
|
||||||
|
:ref:`build_the_network`.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
You can refer to `examples/offline/offline_bcq.py` to see how to use it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder: nn.Module,
|
||||||
|
decoder: nn.Module,
|
||||||
|
hidden_dim: int,
|
||||||
|
latent_dim: int,
|
||||||
|
max_action: float,
|
||||||
|
device: Union[str, torch.device] = "cpu"
|
||||||
|
):
|
||||||
|
super(VAE, self).__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
|
||||||
|
self.mean = nn.Linear(hidden_dim, latent_dim)
|
||||||
|
self.log_std = nn.Linear(hidden_dim, latent_dim)
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
self.max_action = max_action
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, state: torch.Tensor, action: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# [state, action] -> z , [state, z] -> action
|
||||||
|
z = self.encoder(torch.cat([state, action], -1))
|
||||||
|
# shape of z: (state.shape[:-1], hidden_dim)
|
||||||
|
|
||||||
|
mean = self.mean(z)
|
||||||
|
# Clamped for numerical stability
|
||||||
|
log_std = self.log_std(z).clamp(-4, 15)
|
||||||
|
std = torch.exp(log_std)
|
||||||
|
# shape of mean, std: (state.shape[:-1], latent_dim)
|
||||||
|
|
||||||
|
z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim)
|
||||||
|
|
||||||
|
u = self.decode(state, z) # (state.shape[:-1], action_dim)
|
||||||
|
return u, mean, std
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
state: torch.Tensor,
|
||||||
|
z: Union[torch.Tensor, None] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# decode(state) -> action
|
||||||
|
if z is None:
|
||||||
|
# state.shape[0] may be batch_size
|
||||||
|
# latent vector clipped to [-0.5, 0.5]
|
||||||
|
z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \
|
||||||
|
.to(self.device).clamp(-0.5, 0.5)
|
||||||
|
|
||||||
|
# decode z with state!
|
||||||
|
return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user