Cleaned up handling of output_dim retrieval, adding exceptions for erroneous cases

Dominik Jain 2024-01-16 13:25:41 +01:00
parent 20074931d5
commit 022cfb7f78
4 changed files with 60 additions and 18 deletions

View File

@@ -28,6 +28,7 @@ class ScaledObsInputModule(torch.nn.Module):
         super().__init__()
         self.module = module
         self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
         self.output_dim = module.output_dim

     def forward(
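For illustration only (not part of the commit): a minimal sketch of why the wrapper has to copy output_dim. ScaledObs is a simplified, hypothetical stand-in for ScaledObsInputModule, and get_output_dim is the helper added to tianshou.utils.net.common later in this commit.

    import torch
    from torch import nn

    from tianshou.utils.net.common import MLP, get_output_dim

    # MLP defines .output_dim (here 128), so it can act as the wrapped feature net.
    feature_net = MLP(input_dim=4, output_dim=128)

    class ScaledObs(nn.Module):  # simplified stand-in for ScaledObsInputModule
        def __init__(self, module: nn.Module, denom: float = 255.0) -> None:
            super().__init__()
            self.module = module
            self.denom = denom
            # Without this line, get_output_dim(wrapper, None) below would raise,
            # because the wrapper itself would not expose the attribute.
            self.output_dim = module.output_dim

        def forward(self, obs: torch.Tensor) -> torch.Tensor:
            return self.module(obs / self.denom)

    wrapper = ScaledObs(feature_net)
    assert get_output_dim(wrapper, None) == 128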

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, TypeAlias, no_type_check
+from typing import Any, TypeAlias, TypeVar, no_type_check

 import numpy as np
 import torch
@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
+T = TypeVar("T")


 def miniblock(
@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
     @abstractmethod
     def get_output_dim(self) -> int:
         pass
+
+
+def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
+    """Gets the given attribute from the given object or takes the alternative value if it is not present.
+    If both are present, they are required to match.
+
+    :param obj: the object from which to obtain the attribute value
+    :param attr_name: the attribute name
+    :param alt_value: the alternative value for the case where the attribute is not present,
+        which cannot be None if the attribute is not present
+    :return: the value
+    """
+    v = getattr(obj, attr_name, None)
+    if v is not None:
+        if alt_value is not None and v != alt_value:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
+            )
+        return v
+    else:
+        if alt_value is None:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
+            )
+        return alt_value
+
+
+def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
+    """Retrieves the value of the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
+    If both are present, they must match.
+
+    :param module: the module
+    :param alt_value: the alternative value
+    :return: the value
+    """
+    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
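A quick usage sketch (not part of the commit) of the semantics defined above; the throwaway classes WithDim and WithoutDim exist only for this illustration.

    from torch import nn

    from tianshou.utils.net.common import get_output_dim

    class WithDim(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.output_dim = 64

    class WithoutDim(nn.Module):
        pass

    # attribute present, no fallback given -> the attribute wins
    assert get_output_dim(WithDim(), None) == 64
    # attribute present and fallback matches -> fine
    assert get_output_dim(WithDim(), 64) == 64
    # attribute absent, fallback given -> the fallback is used
    assert get_output_dim(WithoutDim(), 32) == 32

    # the erroneous cases now raise instead of silently picking a value
    try:
        get_output_dim(WithDim(), 128)  # attribute and fallback conflict
    except ValueError as e:
        print(e)
    try:
        get_output_dim(WithoutDim(), None)  # neither attribute nor fallback
    except ValueError as e:
        print(e)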

View File

@@ -1,12 +1,18 @@
 import warnings
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any

 import numpy as np
 import torch
 from torch import nn

-from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
+from tianshou.utils.net.common import (
+    MLP,
+    BaseActor,
+    TActionShape,
+    TLinearLayer,
+    get_output_dim,
+)

 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -50,8 +56,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,9 +123,9 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
-            input_dim,  # type: ignore
+            input_dim,
             1,
             hidden_sizes,
             device=self.device,
@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim,  # type: ignore
+                input_dim,
                 self.output_dim,
                 hidden_sizes,
                 device=self.device,
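To make the behavior change at these call sites concrete, here is a hedged sketch (not part of the commit, and assuming the Net and Actor constructor signatures of this tianshou version): the old getattr-based lookup silently preferred the module's own output_dim over a conflicting preprocess_net_output_dim argument, whereas get_output_dim now fails fast on such a conflict.

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.continuous import Actor

    # Net exposes output_dim (here the last hidden size, 128).
    preprocess = Net(state_shape=(8,), hidden_sizes=[128, 128])

    # Consistent hint (or None): construction works as before.
    actor = Actor(preprocess, action_shape=(2,), preprocess_net_output_dim=128)

    # Conflicting hint: previously ignored without warning, now a ValueError at construction time.
    try:
        Actor(preprocess, action_shape=(2,), preprocess_net_output_dim=64)
    except ValueError as e:
        print(e)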

View File

@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any

 import numpy as np
 import torch
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn

 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim


 class Actor(BaseActor):
@@ -51,8 +51,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,8 +117,8 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)

     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Mapping: s -> V(s)."""
@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
     ) -> None:
         last_size = int(np.prod(action_shape))
         super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
-        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
+        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
             device,
         )