161 lines
5.4 KiB
Python
Raw Normal View History

2023-09-20 09:29:34 +02:00
from abc import ABC, abstractmethod
from collections.abc import Sequence
2023-09-20 09:29:34 +02:00
import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.utils.net.common import Net
2023-09-20 09:29:34 +02:00
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
# Type alias for the device specifications accepted by the factories in this module:
# a string, an integer or a torch.device instance.
TDevice = str | int | torch.device
2023-09-20 09:29:34 +02:00
def init_linear_orthogonal(module: torch.nn.Module) -> None:
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    Every :class:`torch.nn.Linear` yielded by ``module.modules()`` has its weight
    re-initialized in place via ``torch.nn.init.orthogonal_`` with a gain of ``sqrt(2)``.
    Linear layers that were created without a bias (``bias=False``) only have their
    weights initialized.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            # nn.Linear(..., bias=False) stores None as its bias; there is nothing to zero then
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
class ActorFactory(ABC):
    """Abstract base class of factories which create actor modules for given environments."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Creates the actor module.

        :param envs: the environments for which the actor is to be created
        :param device: the torch device specification to use for the module
        :return: the actor module
        """

    @staticmethod
    def _init_linear(actor: torch.nn.Module):
        """Applies the default initialization to the linear layers of an actor module.

        All linear layers are initialized via :func:`init_linear_orthogonal`. If the
        actor has an attribute ``mu``, the weights of the linear layers found within
        ``actor.mu`` are additionally multiplied by 0.01.

        :param actor: the actor module.
        """
        init_linear_orthogonal(actor)
        if not hasattr(actor, "mu"):
            return
        # Last policy layer scaling for Gaussian policies on continuous action spaces:
        # shrinking the weights of the mean head makes initial actions have (close to)
        # 0 mean, which helps boost performance;
        # see https://arxiv.org/abs/2006.05990, Fig.24 for details
        mu_linear_layers = [
            layer for layer in actor.mu.modules() if isinstance(layer, torch.nn.Linear)
        ]
        for layer in mu_linear_layers:
            layer.weight.data.mul_(0.01)
class DefaultActorFactory(ActorFactory):
    """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.

    Only continuous environments are currently handled; for these, the creation is
    delegated to :class:`ContinuousActorProbFactory`.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(
        self,
        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
        continuous_unbounded: bool = False,
        continuous_conditioned_sigma: bool = False,
    ):
        """
        :param hidden_sizes: the hidden layer sizes passed on to the created actor's network
        :param continuous_unbounded: forwarded as ``unbounded`` to the continuous actor factory
        :param continuous_conditioned_sigma: forwarded as ``conditioned_sigma`` to the
            continuous actor factory
        """
        self.continuous_unbounded = continuous_unbounded
        self.continuous_conditioned_sigma = continuous_conditioned_sigma
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Creates an actor module matching the type of the given environments.

        :param envs: the environments; their type (``envs.get_type()``) selects the actor
        :param device: the torch device specification passed on to the created module
        :return: the actor module
        :raises NotImplementedError: if the environments are of discrete type
        :raises ValueError: if the environment type is neither continuous nor discrete
        """
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            factory = ContinuousActorProbFactory(
                self.hidden_sizes,
                unbounded=self.continuous_unbounded,
                conditioned_sigma=self.continuous_conditioned_sigma,
            )
            return factory.create_module(envs, device)
        elif env_type == EnvType.DISCRETE:
            raise NotImplementedError
        else:
            raise ValueError(f"{env_type} not supported")
class ContinuousActorFactory(ActorFactory, ABC):
    """Marker base class (type bound) for actor factories which are suitable for continuous action spaces."""
class ContinuousActorProbFactory(ContinuousActorFactory):
    """Factory which creates :class:`ActorProb` modules on top of a :class:`Net` using Tanh activations."""

    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        """
        :param hidden_sizes: the hidden layer sizes of the underlying network
        :param unbounded: passed on as ``unbounded`` to :class:`ActorProb`
        :param conditioned_sigma: passed on as ``conditioned_sigma`` to :class:`ActorProb`;
            if false, the actor's ``sigma_param`` is initialized to the constant -0.5
        """
        self.conditioned_sigma = conditioned_sigma
        self.unbounded = unbounded
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Creates the actor, moves it to the given device and initializes its parameters.

        :param envs: the environments providing the state and action shapes
        :param device: the torch device specification used for the network and the actor
        :return: the initialized actor module
        """
        preprocess_net = Net(
            envs.get_state_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        actor_options = {
            "unbounded": self.unbounded,
            "device": device,
            "conditioned_sigma": self.conditioned_sigma,
        }
        actor = ActorProb(preprocess_net, envs.get_action_shape(), **actor_options)
        actor = actor.to(device)

        # parameter initialization: constant sigma parameter (only when sigma is not
        # conditioned), followed by the default linear layer initialization
        if not self.conditioned_sigma:
            torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)
        return actor
class CriticFactory(ABC):
    """Abstract base class of factories which create critic modules for given environments."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Creates the critic module.

        :param envs: the environments for which the critic is to be created
        :param device: the torch device specification to use for the module
        :param use_action: whether the critic shall make use of the action
        :return: the critic module
        """
class DefaultCriticFactory(CriticFactory):
    """A critic factory which selects a suitable MLP-based critic according to the type of environment.

    Only continuous environments are currently handled; for these, the creation is
    delegated to :class:`ContinuousNetCriticFactory`.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """
        :param hidden_sizes: the hidden layer sizes passed on to the created critic's network
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Creates a critic module matching the type of the given environments.

        :param envs: the environments; their type (``envs.get_type()``) selects the critic
        :param device: the torch device specification passed on to the created module
        :param use_action: passed on to the delegate factory
        :return: the critic module
        :raises NotImplementedError: if the environments are of discrete type
        :raises ValueError: if the environment type is neither continuous nor discrete
        """
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            delegate = ContinuousNetCriticFactory(self.hidden_sizes)
            return delegate.create_module(envs, device, use_action)
        if env_type == EnvType.DISCRETE:
            raise NotImplementedError
        raise ValueError(f"{env_type} not supported")
class ContinuousCriticFactory(CriticFactory, ABC):
    """Marker base class for critic factories; adds no members of its own."""
class ContinuousNetCriticFactory(ContinuousCriticFactory):
    """Factory which creates a continuous critic on top of a :class:`Net` using Tanh activations."""

    def __init__(self, hidden_sizes: Sequence[int]):
        """
        :param hidden_sizes: the hidden layer sizes of the underlying network
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Creates the critic, moves it to the given device and initializes its linear layers.

        :param envs: the environments providing the state shape (and the action shape,
            if ``use_action`` is true)
        :param device: the torch device specification used for the network and the critic
        :param use_action: if true, the environments' action shape is passed to the network
            and ``concat`` is enabled; otherwise an action shape of 0 is used
        :return: the initialized critic module
        """
        if use_action:
            action_shape = envs.get_action_shape()
        else:
            action_shape = 0
        backbone = Net(
            envs.get_state_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = ContinuousCritic(backbone, device=device)
        critic = critic.to(device)
        init_linear_orthogonal(critic)
        return critic