* Use a prefix convention (subclasses carry their superclass's name as a prefix) to make related classes discoverable via IDE autocompletion.
* Use dual naming: add an alternative, concise name that omits the precise OO semantics and keeps only the essential part of the name (which may be more pleasing to users not accustomed to convoluted OO naming).
239 lines
8.0 KiB
Python
239 lines
8.0 KiB
Python
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import TypeAlias

import numpy as np
import torch
from torch import nn

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net
|
|
|
|
# Type alias for the ``device`` argument accepted by the factories in this module; the value is passed
# through unchanged to the module constructors (e.g. ``Net(..., device=device)``) and to ``.to(device)``.
TDevice: TypeAlias = str | int | torch.device
|
|
|
|
|
|
def init_linear_orthogonal(module: torch.nn.Module) -> None:
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    Every ``torch.nn.Linear`` found via ``module.modules()`` gets its weight initialized orthogonally
    with gain ``sqrt(2)``. Linear layers created with ``bias=False`` have no bias tensor; they are
    skipped for the bias step instead of causing an error.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            # ``nn.Linear(..., bias=False)`` stores ``bias = None``; zeros_ would fail on it.
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
|
|
|
|
|
|
class ContinuousActorType(str, Enum):
    """The kinds of actor that can be created for continuous action spaces.

    Members mix in ``str``, so they compare equal to the plain string values
    (e.g. ``ContinuousActorType.GAUSSIAN == "gaussian"``); code passing raw strings keeps working.
    """

    GAUSSIAN = "gaussian"
    DETERMINISTIC = "deterministic"

    def __str__(self) -> str:
        # Keep str()/f-string rendering identical to the former plain-string constants.
        return self.value
|
|
|
|
|
|
class ActorFactory(ABC):
    """Abstract base class for factories that create actor modules for given environments."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        """Creates the actor module.

        :param envs: the environments for which the actor is to be created
        :param device: the device passed on to the created module
        :return: the actor module
        """

    @staticmethod
    def _init_linear(actor: torch.nn.Module) -> None:
        """Initializes linear layers of an actor module using default mechanisms.

        Applies :func:`init_linear_orthogonal` to the whole actor. If the actor has a ``mu``
        attribute, the weights of all linear layers inside ``mu`` are additionally scaled by 0.01.

        :param actor: the actor module.
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies
            # do last policy layer scaling, this will make initial actions have (close to)
            # 0 mean and std, and will help boost performances,
            # see https://arxiv.org/abs/2006.05990, Fig.24 for details
            with torch.no_grad():
                for m in actor.mu.modules():
                    if isinstance(m, torch.nn.Linear):
                        m.weight.mul_(0.01)
|
|
|
|
|
|
class ActorFactoryDefault(ActorFactory):
    """Actor factory that picks a suitable MLP-based actor for the environment's action space.

    Only continuous environments are handled at present: depending on ``continuous_actor_type``,
    creation is delegated to a Gaussian or a deterministic actor factory. Discrete environments raise
    ``NotImplementedError``; any other environment type raises ``ValueError``.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(
        self,
        continuous_actor_type: ContinuousActorType,
        hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
        continuous_unbounded=False,
        continuous_conditioned_sigma=False,
    ):
        """
        :param continuous_actor_type: the kind of actor to create for continuous action spaces
        :param hidden_sizes: the hidden layer sizes of the actor's MLP
        :param continuous_unbounded: passed as ``unbounded`` to the Gaussian actor factory
        :param continuous_conditioned_sigma: passed as ``conditioned_sigma`` to the Gaussian actor factory
        """
        self.hidden_sizes = hidden_sizes
        self.continuous_actor_type = continuous_actor_type
        self.continuous_unbounded = continuous_unbounded
        self.continuous_conditioned_sigma = continuous_conditioned_sigma

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            return self._create_continuous_factory().create_module(envs, device)
        if env_type == EnvType.DISCRETE:
            raise NotImplementedError
        raise ValueError(f"{env_type} not supported")

    def _create_continuous_factory(self) -> "ActorFactoryContinuous":
        """Returns the continuous-action actor factory selected by ``continuous_actor_type``.

        :raises ValueError: if ``continuous_actor_type`` matches none of the known actor types
        """
        actor_type = self.continuous_actor_type
        if actor_type == ContinuousActorType.GAUSSIAN:
            return ActorFactoryContinuousGaussian(
                self.hidden_sizes,
                unbounded=self.continuous_unbounded,
                conditioned_sigma=self.continuous_conditioned_sigma,
            )
        if actor_type == ContinuousActorType.DETERMINISTIC:
            return ActorFactoryContinuousDeterministic(self.hidden_sizes)
        raise ValueError(actor_type)
|
|
|
|
|
|
class ActorFactoryContinuous(ActorFactory, ABC):
    """Abstract base (and type bound) for actor factories whose actors target continuous action spaces.

    It adds no behaviour of its own; concrete subclasses implement ``create_module``.
    """
|
|
|
|
|
|
class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
    """Creates a deterministic actor (``continuous.Actor`` on top of an MLP ``Net``) for continuous actions."""

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: the hidden layer sizes of the MLP preprocessing network"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        preprocess_net = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        action_shape = envs.get_action_shape()
        # The MLP's hidden layers are configured on ``preprocess_net``; the actor itself is given
        # ``hidden_sizes=()`` (presumably meaning no extra hidden layers in the head -- see continuous.Actor).
        actor = continuous.Actor(preprocess_net, action_shape, hidden_sizes=(), device=device)
        return actor.to(device)
|
|
|
|
|
|
class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
    """Creates a Gaussian actor (``continuous.ActorProb`` on top of a tanh MLP ``Net``) for continuous actions."""

    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        """
        :param hidden_sizes: the hidden layer sizes of the MLP preprocessing network
        :param unbounded: passed on as ``unbounded`` to ``continuous.ActorProb``
        :param conditioned_sigma: passed on as ``conditioned_sigma`` to ``continuous.ActorProb``; when it is
            False, the actor's ``sigma_param`` is initialized to the constant -0.5
        """
        self.hidden_sizes = hidden_sizes
        self.unbounded = unbounded
        self.conditioned_sigma = conditioned_sigma

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        preprocess_net = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        gaussian_actor = continuous.ActorProb(
            preprocess_net,
            envs.get_action_shape(),
            unbounded=self.unbounded,
            device=device,
            conditioned_sigma=self.conditioned_sigma,
        ).to(device)
        self._init_params(gaussian_actor)
        return gaussian_actor

    def _init_params(self, gaussian_actor: nn.Module) -> None:
        """Initializes the freshly created actor's parameters in place."""
        # ``sigma_param`` is only touched for the non-conditioned variant (presumably the
        # observation-independent sigma parameter of ActorProb -- confirm there).
        if not self.conditioned_sigma:
            torch.nn.init.constant_(gaussian_actor.sigma_param, -0.5)
        self._init_linear(gaussian_actor)
|
|
|
|
|
|
class CriticFactory(ABC):
    """Abstract base class for factories that create critic modules for given environments."""

    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Creates the critic module.

        :param envs: the environments for which the critic is to be created
        :param device: the device passed on to the created module
        :param use_action: whether the critic is to receive the action as an input in addition to the observation
        :return: the critic module
        """
|
|
|
|
|
|
class CriticFactoryDefault(CriticFactory):
    """Critic factory that picks a suitable MLP-based critic for the environment type.

    Continuous environments are delegated to :class:`CriticFactoryContinuousNet`; discrete environments
    raise ``NotImplementedError`` and any other environment type raises ``ValueError``.
    """

    DEFAULT_HIDDEN_SIZES = (64, 64)

    def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
        """:param hidden_sizes: the hidden layer sizes of the critic's MLP"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        env_type = envs.get_type()
        if env_type == EnvType.CONTINUOUS:
            net_factory = CriticFactoryContinuousNet(self.hidden_sizes)
            return net_factory.create_module(envs, device, use_action)
        if env_type == EnvType.DISCRETE:
            raise NotImplementedError
        raise ValueError(f"{env_type} not supported")
|
|
|
|
|
|
class CriticFactoryContinuous(CriticFactory, ABC):
    """Abstract base (and type bound) for critic factories that are suitable for continuous action spaces."""
|
|
|
|
|
|
class CriticFactoryContinuousNet(CriticFactoryContinuous):
    """Creates an MLP critic (``continuous.Critic`` on top of a tanh ``Net``) for continuous action spaces."""

    def __init__(self, hidden_sizes: Sequence[int]):
        """:param hidden_sizes: the hidden layer sizes of the critic's MLP"""
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        # Without the action as input, an action shape of 0 is passed and ``concat`` is disabled.
        if use_action:
            action_shape = envs.get_action_shape()
        else:
            action_shape = 0
        value_net = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        value_critic = continuous.Critic(value_net, device=device)
        value_critic = value_critic.to(device)
        init_linear_orthogonal(value_critic)
        return value_critic
|
|
|
|
|
|
@dataclass
class ModuleOpt:
    """Container holding a torch module together with the optimizer created for it."""

    # the wrapped module (e.g. an actor or a critic)
    module: torch.nn.Module
    # the optimizer associated with ``module``
    optim: torch.optim.Optimizer
|
|
|
|
|
|
@dataclass
class ActorCriticModuleOpt:
    """Container holding an ``ActorCritic`` module and a single optimizer for it.

    The actor and critic parts are exposed directly via the ``actor`` and ``critic`` properties.
    """

    # the combined actor-critic module
    actor_critic_module: ActorCritic
    # the (single) optimizer associated with ``actor_critic_module``
    optim: torch.optim.Optimizer

    @property
    def actor(self):
        """The ``actor`` attribute of ``actor_critic_module``."""
        combined = self.actor_critic_module
        return combined.actor

    @property
    def critic(self):
        """The ``critic`` attribute of ``actor_critic_module``."""
        combined = self.actor_critic_module
        return combined.critic
|
|
|
|
|
|
class ActorModuleOptFactory:
    """Creates an actor module together with an optimizer for it."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        """
        :param actor_factory: the factory used to create the actor module
        :param optim_factory: the factory used to create the optimizer for the actor
        """
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Creates the actor and its optimizer.

        :param envs: the environments for which the actor is to be created
        :param device: the device passed on to the actor factory
        :param lr: the learning rate passed on to the optimizer factory
        :return: the actor paired with its optimizer
        """
        actor_module = self.actor_factory.create_module(envs, device)
        actor_optim = self.optim_factory.create_optimizer(actor_module, lr)
        return ModuleOpt(module=actor_module, optim=actor_optim)
|
|
|
|
|
|
class CriticModuleOptFactory:
    """Creates a critic module together with an optimizer for it."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        use_action: bool,
    ):
        """
        :param critic_factory: the factory used to create the critic module
        :param optim_factory: the factory used to create the optimizer for the critic
        :param use_action: forwarded to ``critic_factory.create_module``; whether the critic receives the action
        """
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.use_action = use_action

    def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Creates the critic and its optimizer.

        :param envs: the environments for which the critic is to be created
        :param device: the device passed on to the critic factory
        :param lr: the learning rate passed on to the optimizer factory
        :return: the critic paired with its optimizer
        """
        critic_module = self.critic_factory.create_module(envs, device, self.use_action)
        critic_optim = self.optim_factory.create_optimizer(critic_module, lr)
        return ModuleOpt(module=critic_module, optim=critic_optim)
|