Cleaned up handling of output_dim retrieval, adding exceptions for erroneous cases
commit 022cfb7f78 (parent 20074931d5)
@@ -28,6 +28,7 @@ class ScaledObsInputModule(torch.nn.Module):
         super().__init__()
         self.module = module
         self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
         self.output_dim = module.output_dim
 
     def forward(
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
-from typing import Any, TypeAlias, no_type_check
+from typing import Any, TypeAlias, TypeVar, no_type_check
 
 import numpy as np
 import torch
@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
 ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
 TActionShape: TypeAlias = Sequence[int] | int
 TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
+T = TypeVar("T")
 
 
 def miniblock(
@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
     @abstractmethod
     def get_output_dim(self) -> int:
         pass
+
+
+def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
+    """Gets the given attribute from the given object or takes the alternative value if it is not present.
+    If both are present, they are required to match.
+
+    :param obj: the object from which to obtain the attribute value
+    :param attr_name: the attribute name
+    :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None
+        if the attribute is not present
+    :return: the value
+    """
+    v = getattr(obj, attr_name)
+    if v is not None:
+        if alt_value is not None and v != alt_value:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
+            )
+        return v
+    else:
+        if alt_value is None:
+            raise ValueError(
+                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
+            )
+        return alt_value
+
+
+def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
+    """Retrieves value the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
+    If both are present, they must match.
+
+    :param module: the module
+    :param alt_value: the alternative value
+    :return: the value
+    """
+    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
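Illustration: a minimal sketch of how the two helpers added in the hunk above behave. The DummyNet class and the concrete dimensions are invented for this example; the import path is taken from the import changes elsewhere in this diff.

    import torch.nn as nn

    from tianshou.utils.net.common import get_output_dim


    class DummyNet(nn.Module):
        # A stand-in preprocessing network; only the output_dim attribute matters here.
        def __init__(self, output_dim: int | None) -> None:
            super().__init__()
            self.output_dim = output_dim


    print(get_output_dim(DummyNet(128), None))  # 128: attribute is defined, no fallback needed
    print(get_output_dim(DummyNet(128), 128))   # 128: attribute and alternative value agree
    print(get_output_dim(DummyNet(None), 64))   # 64: attribute is None, so the fallback is used
    # get_output_dim(DummyNet(128), 64)   raises ValueError: defined value does not match the alt. value
    # get_output_dim(DummyNet(None), None)  raises ValueError: not defined and no fallback given
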
@@ -1,12 +1,18 @@
 import warnings
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
+from tianshou.utils.net.common import (
+    MLP,
+    BaseActor,
+    TActionShape,
+    TLinearLayer,
+    get_output_dim,
+)
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -50,8 +56,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,9 +123,9 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = 1
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
-            input_dim,  # type: ignore
+            input_dim,
             1,
             hidden_sizes,
             device=self.device,
@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
         self.preprocess = preprocess_net
         self.device = device
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim,  # type: ignore
+                input_dim,
                 self.output_dim,
                 hidden_sizes,
                 device=self.device,
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import torch
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from tianshou.data import Batch, to_torch
-from tianshou.utils.net.common import MLP, BaseActor, TActionShape
+from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
 
 
 class Actor(BaseActor):
@@ -51,8 +51,7 @@ class Actor(BaseActor):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        input_dim = cast(int, input_dim)
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(
             input_dim,
             self.output_dim,
@@ -118,8 +117,8 @@ class Critic(nn.Module):
         self.device = device
         self.preprocess = preprocess_net
         self.output_dim = last_size
-        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
+        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
 
     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Mapping: s -> V(s)."""
@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
     ) -> None:
         last_size = int(np.prod(action_shape))
         super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
-        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
+        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
+        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
             device,
         )
 
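Illustration: the first hunk in this commit notes that ScaledObsInputModule mirrors the wrapped module's output_dim so that downstream get_output_dim calls can resolve it. Below is a minimal sketch of that pattern, assuming the wrapped network defines output_dim; ObsScaler, the inner nn.Linear, and the sizes are invented for this example.

    import torch
    from torch import nn

    from tianshou.utils.net.common import get_output_dim


    class ObsScaler(nn.Module):
        # Scales observations before passing them to the wrapped network.
        def __init__(self, module: nn.Module, denom: float = 255.0) -> None:
            super().__init__()
            self.module = module
            self.denom = denom
            # Mirror the wrapped module's output_dim so downstream modules can retrieve it.
            self.output_dim = module.output_dim

        def forward(self, obs: torch.Tensor) -> torch.Tensor:
            return self.module(obs / self.denom)


    inner = nn.Linear(16, 32)
    inner.output_dim = 32  # the attribute that get_output_dim looks up
    scaler = ObsScaler(inner)
    assert get_output_dim(scaler, None) == 32
    assert scaler(torch.rand(4, 16)).shape == (4, 32)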