207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
import copy
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from tianshou.data import Batch, to_torch
|
|
from tianshou.policy import BasePolicy
|
|
from tianshou.utils.net.continuous import VAE
|
|
|
|
|
|
class BCQPolicy(BasePolicy):
|
|
"""Implementation of BCQ algorithm. arXiv:1812.02900.
|
|
|
|
:param Perturbation actor: the actor perturbation. (s, a -> perturbed a)
|
|
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
|
|
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
|
|
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
|
|
critic network.
|
|
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
|
|
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
|
critic network.
|
|
:param VAE vae: the VAE network, generating actions similar
|
|
to those in batch. (s, a -> generated a)
|
|
:param torch.optim.Optimizer vae_optim: the optimizer for the VAE network.
|
|
:param Union[str, torch.device] device: which device to create this model on.
|
|
Default to "cpu".
|
|
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
|
|
:param float tau: param for soft update of the target network.
|
|
Default to 0.005.
|
|
:param float lmbda: param for Clipped Double Q-learning. Default to 0.75.
|
|
:param int forward_sampled_times: the number of sampled actions in forward
|
|
function. The policy samples many actions and takes the action with the
|
|
max value. Default to 100.
|
|
:param int num_sampled_action: the number of sampled actions in calculating
|
|
target Q. The algorithm samples several actions using VAE, and perturbs
|
|
each action to get the target Q. Default to 10.
|
|
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
|
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
|
|
|
.. seealso::
|
|
|
|
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
|
|
explanation.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
actor: torch.nn.Module,
|
|
actor_optim: torch.optim.Optimizer,
|
|
critic1: torch.nn.Module,
|
|
critic1_optim: torch.optim.Optimizer,
|
|
critic2: torch.nn.Module,
|
|
critic2_optim: torch.optim.Optimizer,
|
|
vae: VAE,
|
|
vae_optim: torch.optim.Optimizer,
|
|
device: Union[str, torch.device] = "cpu",
|
|
gamma: float = 0.99,
|
|
tau: float = 0.005,
|
|
lmbda: float = 0.75,
|
|
forward_sampled_times: int = 100,
|
|
num_sampled_action: int = 10,
|
|
**kwargs: Any
|
|
) -> None:
|
|
# actor is Perturbation!
|
|
super().__init__(**kwargs)
|
|
self.actor = actor
|
|
self.actor_target = copy.deepcopy(self.actor)
|
|
self.actor_optim = actor_optim
|
|
|
|
self.critic1 = critic1
|
|
self.critic1_target = copy.deepcopy(self.critic1)
|
|
self.critic1_optim = critic1_optim
|
|
|
|
self.critic2 = critic2
|
|
self.critic2_target = copy.deepcopy(self.critic2)
|
|
self.critic2_optim = critic2_optim
|
|
|
|
self.vae = vae
|
|
self.vae_optim = vae_optim
|
|
|
|
self.gamma = gamma
|
|
self.tau = tau
|
|
self.lmbda = lmbda
|
|
self.device = device
|
|
self.forward_sampled_times = forward_sampled_times
|
|
self.num_sampled_action = num_sampled_action
|
|
|
|
def train(self, mode: bool = True) -> "BCQPolicy":
|
|
"""Set the module in training mode, except for the target network."""
|
|
self.training = mode
|
|
self.actor.train(mode)
|
|
self.critic1.train(mode)
|
|
self.critic2.train(mode)
|
|
return self
|
|
|
|
def forward(
|
|
self,
|
|
batch: Batch,
|
|
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
|
**kwargs: Any,
|
|
) -> Batch:
|
|
"""Compute action over the given batch data."""
|
|
# There is "obs" in the Batch
|
|
# obs_group: several groups. Each group has a state.
|
|
obs_group: torch.Tensor = to_torch(batch.obs, device=self.device)
|
|
act_group = []
|
|
for obs in obs_group:
|
|
# now obs is (state_dim)
|
|
obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
|
|
# now obs is (forward_sampled_times, state_dim)
|
|
|
|
# decode(obs) generates action and actor perturbs it
|
|
act = self.actor(obs, self.vae.decode(obs))
|
|
# now action is (forward_sampled_times, action_dim)
|
|
q1 = self.critic1(obs, act)
|
|
# q1 is (forward_sampled_times, 1)
|
|
max_indice = q1.argmax(0)
|
|
act_group.append(act[max_indice].cpu().data.numpy().flatten())
|
|
act_group = np.array(act_group)
|
|
return Batch(act=act_group)
|
|
|
|
def sync_weight(self) -> None:
|
|
"""Soft-update the weight for the target network."""
|
|
self.soft_update(self.critic1_target, self.critic1, self.tau)
|
|
self.soft_update(self.critic2_target, self.critic2, self.tau)
|
|
self.soft_update(self.actor_target, self.actor, self.tau)
|
|
|
|
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
|
# batch: obs, act, rew, done, obs_next. (numpy array)
|
|
# (batch_size, state_dim)
|
|
batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
|
|
obs, act = batch.obs, batch.act
|
|
batch_size = obs.shape[0]
|
|
|
|
# mean, std: (state.shape[0], latent_dim)
|
|
recon, mean, std = self.vae(obs, act)
|
|
recon_loss = F.mse_loss(act, recon)
|
|
# (....) is D_KL( N(mu, sigma) || N(0,1) )
|
|
KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean()
|
|
vae_loss = recon_loss + KL_loss / 2
|
|
|
|
self.vae_optim.zero_grad()
|
|
vae_loss.backward()
|
|
self.vae_optim.step()
|
|
|
|
# critic training:
|
|
with torch.no_grad():
|
|
# repeat num_sampled_action times
|
|
obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0)
|
|
# now obs_next: (num_sampled_action * batch_size, state_dim)
|
|
|
|
# perturbed action generated by VAE
|
|
act_next = self.vae.decode(obs_next)
|
|
# now obs_next: (num_sampled_action * batch_size, action_dim)
|
|
target_Q1 = self.critic1_target(obs_next, act_next)
|
|
target_Q2 = self.critic2_target(obs_next, act_next)
|
|
|
|
# Clipped Double Q-learning
|
|
target_Q = \
|
|
self.lmbda * torch.min(target_Q1, target_Q2) + \
|
|
(1 - self.lmbda) * torch.max(target_Q1, target_Q2)
|
|
# now target_Q: (num_sampled_action * batch_size, 1)
|
|
|
|
# the max value of Q
|
|
target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1)
|
|
# now target_Q: (batch_size, 1)
|
|
|
|
target_Q = \
|
|
batch.rew.reshape(-1, 1) + \
|
|
(1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
|
|
|
|
current_Q1 = self.critic1(obs, act)
|
|
current_Q2 = self.critic2(obs, act)
|
|
|
|
critic1_loss = F.mse_loss(current_Q1, target_Q)
|
|
critic2_loss = F.mse_loss(current_Q2, target_Q)
|
|
|
|
self.critic1_optim.zero_grad()
|
|
self.critic2_optim.zero_grad()
|
|
critic1_loss.backward()
|
|
critic2_loss.backward()
|
|
self.critic1_optim.step()
|
|
self.critic2_optim.step()
|
|
|
|
sampled_act = self.vae.decode(obs)
|
|
perturbed_act = self.actor(obs, sampled_act)
|
|
|
|
# max
|
|
actor_loss = -self.critic1(obs, perturbed_act).mean()
|
|
|
|
self.actor_optim.zero_grad()
|
|
actor_loss.backward()
|
|
self.actor_optim.step()
|
|
|
|
# update target network
|
|
self.sync_weight()
|
|
|
|
result = {
|
|
"loss/actor": actor_loss.item(),
|
|
"loss/critic1": critic1_loss.item(),
|
|
"loss/critic2": critic2_loss.item(),
|
|
"loss/vae": vae_loss.item(),
|
|
}
|
|
return result
|