Dominik Jain 78b6dd1f49 Adapt class naming scheme
* Use prefix convention (subclasses have superclass names as prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming)
2023-10-18 20:44:16 +02:00

239 lines
8.0 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous
from tianshou.utils.net.common import ActorCritic, Net
TDevice: TypeAlias = str | int | torch.device
def init_linear_orthogonal(module: torch.nn.Module):
"""Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
:param module: the module whose submodules are to be processed
"""
for m in module.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
class ContinuousActorType:
GAUSSIAN = "gaussian"
DETERMINISTIC = "deterministic"
class ActorFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
pass
@staticmethod
def _init_linear(actor: torch.nn.Module):
"""Initializes linear layers of an actor module using default mechanisms.
:param module: the actor module.
"""
init_linear_orthogonal(actor)
if hasattr(actor, "mu"):
# For continuous action spaces with Gaussian policies
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
m.weight.data.copy_(0.01 * m.weight.data)
class ActorFactoryDefault(ActorFactory):
"""An actor factory which, depending on the type of environment, creates a suitable MLP-based policy."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(
self,
continuous_actor_type: ContinuousActorType,
hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES,
continuous_unbounded=False,
continuous_conditioned_sigma=False,
):
self.continuous_actor_type = continuous_actor_type
self.continuous_unbounded = continuous_unbounded
self.continuous_conditioned_sigma = continuous_conditioned_sigma
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
match self.continuous_actor_type:
case ContinuousActorType.GAUSSIAN:
factory = ActorFactoryContinuousGaussian(
self.hidden_sizes,
unbounded=self.continuous_unbounded,
conditioned_sigma=self.continuous_conditioned_sigma,
)
case ContinuousActorType.DETERMINISTIC:
factory = ActorFactoryContinuousDeterministic(self.hidden_sizes)
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class ActorFactoryContinuous(ActorFactory, ABC):
"""Serves as a type bound for actor factories that are suitable for continuous action spaces."""
class ActorFactoryContinuousDeterministic(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
device=device,
)
return continuous.Actor(
net_a,
envs.get_action_shape(),
hidden_sizes=(),
device=device,
).to(device)
class ActorFactoryContinuousGaussian(ActorFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net(
envs.get_observation_shape(),
hidden_sizes=self.hidden_sizes,
activation=nn.Tanh,
device=device,
)
actor = continuous.ActorProb(
net_a,
envs.get_action_shape(),
unbounded=self.unbounded,
device=device,
conditioned_sigma=self.conditioned_sigma,
).to(device)
# init params
if not self.conditioned_sigma:
torch.nn.init.constant_(actor.sigma_param, -0.5)
self._init_linear(actor)
return actor
class CriticFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass
class CriticFactoryDefault(CriticFactory):
"""A critic factory which, depending on the type of environment, creates a suitable MLP-based critic."""
DEFAULT_HIDDEN_SIZES = (64, 64)
def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
env_type = envs.get_type()
if env_type == EnvType.CONTINUOUS:
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
else:
raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
@dataclass
class ModuleOpt:
module: torch.nn.Module
optim: torch.optim.Optimizer
@dataclass
class ActorCriticModuleOpt:
actor_critic_module: ActorCritic
optim: torch.optim.Optimizer
@property
def actor(self):
return self.actor_critic_module.actor
@property
def critic(self):
return self.actor_critic_module.critic
class ActorModuleOptFactory:
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
actor = self.actor_factory.create_module(envs, device)
opt = self.optim_factory.create_optimizer(actor, lr)
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
def __init__(
self,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
use_action: bool,
):
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.use_action = use_action
def create_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
critic = self.critic_factory.create_module(envs, device, self.use_action)
opt = self.optim_factory.create_optimizer(critic, lr)
return ModuleOpt(critic, opt)