diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 6cd4a0f..0b28f98 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
     TLinearLayer,
     get_output_dim,
 )
+from tianshou.utils.pickle import setstate

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -109,6 +110,9 @@ class Critic(CriticBase):
         `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
     :param linear_layer: use this module as linear layer.
     :param flatten_input: whether to flatten input data for the last layer.
+    :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
+        concatenating with the action) - and without the observations being modified in any way beforehand.
+        This allows the actor's preprocessing network to be reused for the critic.

     For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
@@ -122,11 +126,13 @@ class Critic(CriticBase):
         preprocess_net_output_dim: int | None = None,
         linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
+        apply_preprocess_net_to_obs_only: bool = False,
     ) -> None:
         super().__init__()
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
+        self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
         input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
@@ -137,6 +143,14 @@ class Critic(CriticBase):
             flatten_input=flatten_input,
         )

+    def __setstate__(self, state: dict) -> None:
+        setstate(
+            Critic,
+            self,
+            state,
+            new_default_properties={"apply_preprocess_net_to_obs_only": False},
+        )
+
     def forward(
         self,
         obs: np.ndarray | torch.Tensor,
@@ -148,7 +162,10 @@
             obs,
             device=self.device,
             dtype=torch.float32,
-        ).flatten(1)
+        )
+        if self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        obs = obs.flatten(1)
         if act is not None:
             act = torch.as_tensor(
                 act,
@@ -156,8 +173,9 @@
                 dtype=torch.float32,
             ).flatten(1)
             obs = torch.cat([obs, act], dim=1)
-        values_B, hidden_BH = self.preprocess(obs)
-        return self.last(values_B)
+        if not self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        return self.last(obs)


 class ActorProb(BaseActor):
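As a minimal sketch (not part of the patch), the new flag could be used roughly as follows to share one preprocessing module between actor and critic. The shapes, hidden sizes, and variable names (`state_shape`, `action_shape`, `preprocess_net`) are illustrative assumptions; `Net`, `ActorProb`, and `Critic` are the existing tianshou classes.

```python
# Sketch only: reuse the actor's preprocessing network for the critic via the
# new apply_preprocess_net_to_obs_only flag. Shapes/names are hypothetical.
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

state_shape, action_shape = (8,), (2,)

# A single preprocessing network, shared by actor and critic.
preprocess_net = Net(state_shape=state_shape, hidden_sizes=[64, 64], device="cpu")
actor = ActorProb(preprocess_net=preprocess_net, action_shape=action_shape, device="cpu")

# With the flag set, the critic first runs preprocess_net on the raw
# observations and feeds the resulting features into its own MLP head.
critic = Critic(
    preprocess_net=preprocess_net,
    device="cpu",
    apply_preprocess_net_to_obs_only=True,
)

obs = torch.zeros(4, *state_shape)
(mu, sigma), _ = actor(obs)   # actor uses preprocess_net internally
values = critic(obs)          # state-value output, shape (4, 1)
```

Note that this sketch passes no action to the critic: since `input_dim` is still derived from `get_output_dim(preprocess_net, ...)`, the flag is most directly suited to state-value critics that reuse the actor's feature extractor.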