From 022cfb7f78bd29c516524f2dce87c62c65acad14 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 16 Jan 2024 13:25:41 +0100
Subject: [PATCH] Cleaned up handling of output_dim retrieval, adding
 exceptions for erroneous cases

---
 examples/atari/atari_network.py  |  1 +
 tianshou/utils/net/common.py     | 39 +++++++++++++++++++++++++++++++-
 tianshou/utils/net/continuous.py | 23 +++++++++++--------
 tianshou/utils/net/discrete.py   | 15 ++++++------
 4 files changed, 60 insertions(+), 18 deletions(-)

diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 2e49478..7266ead 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -28,6 +28,7 @@ class ScaledObsInputModule(torch.nn.Module):
         super().__init__()
         self.module = module
         self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
         self.output_dim = module.output_dim
 
     def forward(
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index d1a1135..3c3bf55 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, TypeAlias, no_type_check
+from typing import Any, TypeAlias, TypeVar, no_type_check
 
 import numpy as np
 import torch
@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
+T = TypeVar("T")
 
 
 def miniblock(
@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
     @abstractmethod
     def get_output_dim(self) -> int:
         pass
+
+
+def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
+    """Gets the given attribute from the given object or takes the alternative value if it is not present.
+    If both are present, they are required to match.
+
+    :param obj: the object from which to obtain the attribute value
+    :param attr_name: the attribute name
+    :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None
+        if the attribute is not present
+    :return: the value
+    """
+    v = getattr(obj, attr_name)
+    if v is not None:
+        if alt_value is not None and v != alt_value:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
+            )
+        return v
+    else:
+        if alt_value is None:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
+            )
+        return alt_value
+
+
+def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
+    """Retrieves the value of the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
+    If both are present, they must match.
+
+    :param module: the module
+    :param alt_value: the alternative value
+    :return: the value
+    """
+    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 31c9efb..f257f8a 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -1,12 +1,18 @@
 import warnings
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
+from tianshou.utils.net.common import (
+    MLP,
+    BaseActor,
+    TActionShape,
+    TLinearLayer,
+    get_output_dim,
+)
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -50,8 +56,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,9 +123,9 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
-            input_dim,  # type: ignore
+            input_dim,
             1,
             hidden_sizes,
             device=self.device,
@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim,  # type: ignore
+                input_dim,
                 self.output_dim,
                 hidden_sizes,
                 device=self.device,
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index ae77c5d..5417cb7 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
 
 
 class Actor(BaseActor):
@@ -51,8 +51,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,8 +117,8 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
 
     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Mapping: s -> V(s)."""
@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
     ) -> None:
         last_size = int(np.prod(action_shape))
         super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
-        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
+        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
             device,
         )
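
Note (not part of the patch): a minimal sketch of how the new helper is intended to be used, under the assumption that a preprocess network exposes an `output_dim` attribute. `DummyPreprocessNet` and the dimension 16 are made up for illustration; only `get_output_dim` and the `output_dim` convention come from this change.

import torch
from torch import nn

from tianshou.utils.net.common import get_output_dim


class DummyPreprocessNet(nn.Module):
    """Hypothetical preprocess network used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 16)
        # Exposing output_dim lets downstream modules (Actor/Critic) infer their input size.
        self.output_dim = 16

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.model(obs)


net = DummyPreprocessNet()
print(get_output_dim(net, None))  # -> 16, taken from net.output_dim
print(get_output_dim(net, 16))    # -> 16, attribute and alternative value agree
# get_output_dim(net, 32) raises ValueError, since the attribute (16) contradicts the
# explicitly passed value (32); before this patch such a mismatch was silently ignored.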