From 0b494845c915b944c01751f8b6c737f0d75d37a6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 29 Apr 2024 14:06:32 +0200 Subject: [PATCH] continuous.Critic: Add flag apply_preprocess_net_to_obs_only to allow the preprocessing network to be applied to the observations only (without the actions concatenated), which is essential for the case where we want to reuse the actor's preprocessing network --- tianshou/utils/net/continuous.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 6cd4a0f..0b28f98 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -15,6 +15,7 @@ from tianshou.utils.net.common import ( TLinearLayer, get_output_dim, ) +from tianshou.utils.pickle import setstate SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -109,6 +110,9 @@ class Critic(CriticBase): `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. + :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before + concatenating with the action) - and without the observations being modified in any way beforehand. + This allows the actor's preprocessing network to be reused for the critic. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. @@ -122,11 +126,13 @@ class Critic(CriticBase): preprocess_net_output_dim: int | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, + apply_preprocess_net_to_obs_only: bool = False, ) -> None: super().__init__() self.device = device self.preprocess = preprocess_net self.output_dim = 1 + self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP( input_dim, @@ -137,6 +143,14 @@ class Critic(CriticBase): flatten_input=flatten_input, ) + def __setstate__(self, state: dict) -> None: + setstate( + Critic, + self, + state, + new_default_properties={"apply_preprocess_net_to_obs_only": False}, + ) + def forward( self, obs: np.ndarray | torch.Tensor, @@ -148,7 +162,10 @@ class Critic(CriticBase): obs, device=self.device, dtype=torch.float32, - ).flatten(1) + ) + if self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + obs = obs.flatten(1) if act is not None: act = torch.as_tensor( act, @@ -156,8 +173,9 @@ class Critic(CriticBase): dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) - values_B, hidden_BH = self.preprocess(obs) - return self.last(values_B) + if not self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + return self.last(obs) class ActorProb(BaseActor):