continuous.Critic: Add flag apply_preprocess_net_to_obs_only

The flag allows the preprocessing network to be applied to the observations
only (without the actions concatenated), which is essential for the case
where we want to reuse the actor's preprocessing network for the critic.
This commit is contained in:
Dominik Jain 2024-04-29 14:06:32 +02:00
parent 18ed981875
commit 0b494845c9

View File

@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
TLinearLayer,
get_output_dim,
)
from tianshou.utils.pickle import setstate
SIGMA_MIN = -20
SIGMA_MAX = 2
@ -109,6 +110,9 @@ class Critic(CriticBase):
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer.
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
concatenating with the action) - and without the observations being modified in any way beforehand.
This allows the actor's preprocessing network to be reused for the critic.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
@ -122,11 +126,13 @@ class Critic(CriticBase):
preprocess_net_output_dim: int | None = None,
linear_layer: TLinearLayer = nn.Linear,
flatten_input: bool = True,
apply_preprocess_net_to_obs_only: bool = False,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.output_dim = 1
self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(
input_dim,
@ -137,6 +143,14 @@ class Critic(CriticBase):
flatten_input=flatten_input,
)
def __setstate__(self, state: dict) -> None:
    """Restore this critic from a pickled ``state`` dictionary.

    Delegates to ``setstate``, declaring ``False`` as the default value of
    ``apply_preprocess_net_to_obs_only``.

    :param state: the state dictionary to restore the instance from.
    """
    # NOTE(review): presumably this default is applied when `state` predates
    # the flag, so that older pickles keep the previous behaviour - confirm
    # against the implementation of `setstate`.
    defaults_for_new_properties = {"apply_preprocess_net_to_obs_only": False}
    setstate(Critic, self, state, new_default_properties=defaults_for_new_properties)
def forward(
self,
obs: np.ndarray | torch.Tensor,
@ -148,7 +162,10 @@ class Critic(CriticBase):
obs,
device=self.device,
dtype=torch.float32,
).flatten(1)
)
if self.apply_preprocess_net_to_obs_only:
obs, _ = self.preprocess(obs)
obs = obs.flatten(1)
if act is not None:
act = torch.as_tensor(
act,
@ -156,8 +173,9 @@ class Critic(CriticBase):
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
values_B, hidden_BH = self.preprocess(obs)
return self.last(values_B)
if not self.apply_preprocess_net_to_obs_only:
obs, _ = self.preprocess(obs)
return self.last(obs)
class ActorProb(BaseActor):