continuous.Critic: Add flag apply_preprocess_net_to_obs_only to allow the
preprocessing network to be applied to the observations only (without the
actions concatenated). This is essential when we want to reuse the actor's
preprocessing network for the critic.
parent 18ed981875
commit 0b494845c9
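For context, here is a minimal, self-contained sketch (plain PyTorch, not code from this commit; module names and shapes are placeholders) of the two code paths the new flag selects between in Critic.forward:

# Hypothetical stand-ins: `preprocess` mimics an actor's observation encoder.
import torch
import torch.nn as nn

obs_dim, act_dim, feat_dim, batch = 8, 2, 64, 4
preprocess = nn.Sequential(nn.Linear(obs_dim, feat_dim), nn.ReLU())
obs = torch.zeros(batch, obs_dim)
act = torch.zeros(batch, act_dim)

# apply_preprocess_net_to_obs_only=True: the preprocessing network sees the
# raw observations only; the action is concatenated afterwards, so the same
# network can be shared with the actor (which never sees actions).
feats = preprocess(obs).flatten(1)
q_input_shared = torch.cat([feats, act], dim=1)  # shape (batch, feat_dim + act_dim)

# apply_preprocess_net_to_obs_only=False (previous behaviour): obs and act are
# concatenated first, so the preprocessing network must accept
# obs_dim + act_dim inputs and therefore cannot be the actor's network.
preprocess_sa = nn.Sequential(nn.Linear(obs_dim + act_dim, feat_dim), nn.ReLU())
q_input_separate = preprocess_sa(torch.cat([obs.flatten(1), act.flatten(1)], dim=1))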
@@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
     TLinearLayer,
     get_output_dim,
 )
+from tianshou.utils.pickle import setstate
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -109,6 +110,9 @@ class Critic(CriticBase):
         `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
     :param linear_layer: use this module as linear layer.
     :param flatten_input: whether to flatten input data for the last layer.
+    :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
+        concatenating with the action) - and without the observations being modified in any way beforehand.
+        This allows the actor's preprocessing network to be reused for the critic.
 
     For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
@@ -122,11 +126,13 @@ class Critic(CriticBase):
         preprocess_net_output_dim: int | None = None,
         linear_layer: TLinearLayer = nn.Linear,
         flatten_input: bool = True,
+        apply_preprocess_net_to_obs_only: bool = False,
     ) -> None:
         super().__init__()
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
+        self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
         input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
@@ -137,6 +143,14 @@ class Critic(CriticBase):
             flatten_input=flatten_input,
         )
 
+    def __setstate__(self, state: dict) -> None:
+        setstate(
+            Critic,
+            self,
+            state,
+            new_default_properties={"apply_preprocess_net_to_obs_only": False},
+        )
+
     def forward(
         self,
         obs: np.ndarray | torch.Tensor,
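The added __setstate__ hook keeps previously pickled Critic instances loadable. Roughly (a sketch of the idea only, not the tianshou.utils.pickle.setstate implementation), it fills in the new attribute with its pre-commit default before the state is restored:

# Rough sketch only: objects pickled before this commit have no
# "apply_preprocess_net_to_obs_only" entry in their state dict, so a default
# of False is supplied before the attributes are restored.
def restore_with_new_defaults(obj, state: dict, new_default_properties: dict) -> None:
    for name, default in new_default_properties.items():
        state.setdefault(name, default)
    obj.__dict__.update(state)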
@@ -148,7 +162,10 @@ class Critic(CriticBase):
             obs,
             device=self.device,
             dtype=torch.float32,
-        ).flatten(1)
+        )
+        if self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        obs = obs.flatten(1)
         if act is not None:
             act = torch.as_tensor(
                 act,
@@ -156,8 +173,9 @@ class Critic(CriticBase):
                 dtype=torch.float32,
             ).flatten(1)
             obs = torch.cat([obs, act], dim=1)
-        values_B, hidden_BH = self.preprocess(obs)
-        return self.last(values_B)
+        if not self.apply_preprocess_net_to_obs_only:
+            obs, _ = self.preprocess(obs)
+        return self.last(obs)
 
 
 class ActorProb(BaseActor):