Cleaned up handling of output_dim retrieval, adding exceptions for erroneous cases
parent 20074931d5
commit 022cfb7f78
@@ -28,6 +28,7 @@ class ScaledObsInputModule(torch.nn.Module):
         super().__init__()
         self.module = module
         self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
         self.output_dim = module.output_dim
 
     def forward(
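The comment added here documents a convention rather than enforcing one: a wrapper module must re-expose the wrapped module's output_dim so that downstream consumers can size their layers from it. A minimal sketch of such a consumer (illustrative only, not part of the commit; the Head class is hypothetical, and get_output_dim is the helper introduced in the hunk further below):

import torch
from torch import nn

from tianshou.utils.net.common import get_output_dim


class Head(nn.Module):
    """Hypothetical downstream module that sizes its input from a preprocess net."""

    def __init__(self, preprocess: nn.Module, out_features: int) -> None:
        super().__init__()
        # This lookup succeeds only because the wrapper copied module.output_dim;
        # a module lacking the attribute would need an explicit fallback instead
        # of None (see get_output_dim below).
        input_dim = get_output_dim(preprocess, None)
        self.linear = nn.Linear(input_dim, out_features)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.linear(obs)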
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, TypeAlias, no_type_check
+from typing import Any, TypeAlias, TypeVar, no_type_check
 
 import numpy as np
 import torch
@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
+T = TypeVar("T")
 
 
 def miniblock(
@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
     @abstractmethod
     def get_output_dim(self) -> int:
         pass
+
+
+def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
+    """Gets the given attribute from the given object or takes the alternative value if it is not present.
+    If both are present, they are required to match.
+
+    :param obj: the object from which to obtain the attribute value
+    :param attr_name: the attribute name
+    :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None
+        if the attribute is not present
+    :return: the value
+    """
+    v = getattr(obj, attr_name, None)
+    if v is not None:
+        if alt_value is not None and v != alt_value:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
+            )
+        return v
+    else:
+        if alt_value is None:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
+            )
+        return alt_value
+
+
+def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
+    """Retrieves the value of the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
+    If both are present, they must match.
+
+    :param module: the module
+    :param alt_value: the alternative value
+    :return: the value
+    """
+    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
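Together, the two new helpers implement a fail-fast lookup: a defined attribute wins, the fallback fills a gap, and a disagreement or a double absence raises. A short sketch of the intended behavior (illustrative, not part of the commit; the Linear and ReLU modules are stand-ins for tianshou's preprocess nets):

import torch

from tianshou.utils.net.common import get_output_dim

net = torch.nn.Linear(4, 16)
net.output_dim = 16  # the attribute convention that tianshou's nets follow

assert get_output_dim(net, None) == 16  # attribute present, no fallback needed
assert get_output_dim(net, 16) == 16    # attribute and fallback agree
# get_output_dim(net, 32)               -> ValueError: defined but mismatched
# get_output_dim(torch.nn.ReLU(), None) -> ValueError: absent with no fallback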
@@ -1,12 +1,18 @@
 import warnings
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
+from tianshou.utils.net.common import (
+    MLP,
+    BaseActor,
+    TActionShape,
+    TLinearLayer,
+    get_output_dim,
+)
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -50,8 +56,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,9 +123,9 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
-            input_dim,  # type: ignore
+            input_dim,
             1,
             hidden_sizes,
             device=self.device,
@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim,  # type: ignore
+                input_dim,
                 self.output_dim,
                 hidden_sizes,
                 device=self.device,
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
 
 
 class Actor(BaseActor):
@@ -51,8 +51,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,8 +117,8 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
 
     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Mapping: s -> V(s)."""
@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
     ) -> None:
         last_size = int(np.prod(action_shape))
         super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
-        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
+        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
             device,
         )
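The net effect at the call sites above: the old getattr pattern silently preferred the module's attribute and needed a cast (or a # type: ignore) to satisfy the type checker, whereas a conflicting preprocess_net_output_dim now raises instead of being ignored. A before/after sketch (illustrative, with a hypothetical preprocess module):

import torch

from tianshou.utils.net.common import get_output_dim

preprocess = torch.nn.Linear(8, 64)
preprocess.output_dim = 64

# Old pattern: the attribute silently wins, so a wrong explicit value goes unnoticed.
input_dim = getattr(preprocess, "output_dim", 128)  # -> 64; the 128 is ignored

# New pattern: agreement is checked, so the erroneous case surfaces immediately.
input_dim = get_output_dim(preprocess, 64)  # -> 64
# get_output_dim(preprocess, 128)           -> ValueError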