Cleaned up handling of output_dim retrieval, adding exceptions for erroneous cases

Dominik Jain 2024-01-16 13:25:41 +01:00
parent 20074931d5
commit 022cfb7f78
4 changed files with 60 additions and 18 deletions

View File

@@ -28,6 +28,7 @@ class ScaledObsInputModule(torch.nn.Module):
         super().__init__()
         self.module = module
         self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
         self.output_dim = module.output_dim
 
     def forward(
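To illustrate the comment above, here is a minimal sketch (not part of the commit; `InnerNet` and `ScalingWrapper` are hypothetical stand-ins for a preprocess network and for `ScaledObsInputModule`): downstream code reads `output_dim` via `get_output_dim`, so a wrapper that does not re-expose the attribute would make that lookup fail.

```python
import torch


class InnerNet(torch.nn.Module):
    """Hypothetical preprocess network that exposes its output dimension."""

    def __init__(self) -> None:
        super().__init__()
        self.output_dim = 128
        self.linear = torch.nn.Linear(4, self.output_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.linear(obs)


class ScalingWrapper(torch.nn.Module):
    """Mirrors ScaledObsInputModule's pattern: delegate and re-expose output_dim."""

    def __init__(self, module: InnerNet, denom: float = 255.0) -> None:
        super().__init__()
        self.module = module
        self.denom = denom
        # Without this line, get_output_dim(wrapper, None) would find no
        # attribute and raise a ValueError.
        self.output_dim = module.output_dim

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.module(obs / self.denom)


wrapped = ScalingWrapper(InnerNet())
assert wrapped.output_dim == 128
```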

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, TypeAlias, no_type_check
+from typing import Any, TypeAlias, TypeVar, no_type_check
 
 import numpy as np
 import torch
@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
+T = TypeVar("T")
 
 
 def miniblock(
@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
     @abstractmethod
     def get_output_dim(self) -> int:
         pass
+
+
+def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
+    """Gets the given attribute from the given object or takes the alternative value if it is not present.
+
+    If both are present, they are required to match.
+
+    :param obj: the object from which to obtain the attribute value
+    :param attr_name: the attribute name
+    :param alt_value: the alternative value to use if the attribute is not present;
+        it must not be None in that case
+    :return: the value
+    """
+    # Use a None default so a missing attribute falls through to alt_value
+    # instead of raising AttributeError.
+    v = getattr(obj, attr_name, None)
+    if v is not None:
+        if alt_value is not None and v != alt_value:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
+            )
+        return v
+    else:
+        if alt_value is None:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
+            )
+        return alt_value
+
+
+def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
+    """Retrieves the value of the `output_dim` attribute of the given module,
+    or uses the given alternative value if the attribute is not present.
+
+    If both are present, they must match.
+
+    :param module: the module
+    :param alt_value: the alternative value
+    :return: the value
+    """
+    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
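A usage sketch of the semantics introduced here (the toy class is hypothetical; the calls assume the two helpers defined above):

```python
from torch import nn

from tianshou.utils.net.common import get_output_dim


class WithDim(nn.Module):
    """Toy module that declares its output dimension."""

    def __init__(self) -> None:
        super().__init__()
        self.output_dim = 64


assert get_output_dim(WithDim(), None) == 64  # attribute present, no fallback needed
assert get_output_dim(WithDim(), 64) == 64    # attribute and fallback present and matching
assert get_output_dim(nn.Module(), 32) == 32  # attribute absent, fallback used

# The erroneous cases now raise instead of silently propagating a wrong dimension:
# get_output_dim(WithDim(), 128)     # ValueError: defined (64) but does not match alt. value (128)
# get_output_dim(nn.Module(), None)  # ValueError: not defined and no fallback given
```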

View File

@@ -1,12 +1,18 @@
 import warnings
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
+from tianshou.utils.net.common import (
+    MLP,
+    BaseActor,
+    TActionShape,
+    TLinearLayer,
+    get_output_dim,
+)
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -50,8 +56,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,9 +123,9 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
-            input_dim,  # type: ignore
+            input_dim,
             1,
             hidden_sizes,
             device=self.device,
@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim,  # type: ignore
+                input_dim,
                 self.output_dim,
                 hidden_sizes,
                 device=self.device,
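The practical effect of replacing `getattr` + `cast` in these constructors: a conflict between the preprocessing network's `output_dim` and an explicitly passed `preprocess_net_output_dim` now fails fast at construction time rather than surfacing later as a shape error inside the MLP. A sketch of the new behavior, assuming tianshou's `Net` as the preprocess network (values are illustrative):

```python
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor

# Net exposes output_dim == 64 here (the last hidden size).
net = Net(state_shape=(4,), hidden_sizes=[64, 64])

actor = Actor(net, action_shape=(2,), preprocess_net_output_dim=64)  # OK: values match
try:
    Actor(net, action_shape=(2,), preprocess_net_output_dim=128)
except ValueError as e:
    print(e)  # ... 'output_dim' ... is defined (64) but does not match alt. value (128)
```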

View File

@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
 
 
 class Actor(BaseActor):
@@ -51,8 +51,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,8 +117,8 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
 
     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Mapping: s -> V(s)."""
@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
     ) -> None:
         last_size = int(np.prod(action_shape))
         super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
-        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
+        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
             device,
         )