from abc import ABC, abstractmethod
from collections.abc import Sequence

import numpy as np
import torch
from torch import nn

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic

TDevice = str | int | torch.device


def init_linear_orthogonal(module: torch.nn.Module):
    """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
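

# Illustrative sketch (not part of the module's public API): one way init_linear_orthogonal
# could be applied to an arbitrary torch module containing Linear layers. The small MLP
# below is a hypothetical example constructed only for demonstration.
def _example_init_linear_orthogonal_usage() -> nn.Module:
    mlp = nn.Sequential(
        nn.Linear(8, 64),
        nn.Tanh(),
        nn.Linear(64, 2),
    )
    # After this call, every Linear weight is orthogonal (gain=sqrt(2)) and every bias is zero.
    init_linear_orthogonal(mlp)
    return mlp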


class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass

    @staticmethod
    def _init_linear(actor: torch.nn.Module):
        """Initializes linear layers of an actor module using default mechanisms.

        :param actor: the actor module.
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies, scale the last
            # policy layer so that initial actions have (close to) 0 mean and std,
            # which helps boost performance;
            # see https://arxiv.org/abs/2006.05990, Fig. 24 for details.
            for m in actor.mu.modules():
                if isinstance(m, torch.nn.Linear):
                    m.weight.data.copy_(0.01 * m.weight.data)
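

# Illustrative sketch (hypothetical, not used by the factories below): the effect of
# ActorFactory._init_linear on a module that exposes a `mu` head, analogous to ActorProb.
# All Linear layers receive orthogonal initialization with zeroed biases, and the `mu`
# head's weights are additionally scaled by 0.01 so initial action means stay near zero.
def _example_init_linear_effect() -> nn.Module:
    class _ToyActor(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.backbone = nn.Linear(4, 16)
            self.mu = nn.Linear(16, 2)  # Gaussian mean head, as in ActorProb

    toy = _ToyActor()
    ActorFactory._init_linear(toy)
    # toy.mu.weight now holds 0.01 * (orthogonally initialized weights); all biases are zero.
    return toy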


class ContinuousActorFactory(ActorFactory, ABC):
    pass


class ContinuousActorProbFactory(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
        self.hidden_sizes = hidden_sizes
        self.unbounded = unbounded
        self.conditioned_sigma = conditioned_sigma

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_state_shape(),
            hidden_sizes=self.hidden_sizes,
            activation=nn.Tanh,
            device=device,
        )
        actor = ActorProb(
            net_a,
            envs.get_action_shape(),
            unbounded=self.unbounded,
            device=device,
            conditioned_sigma=self.conditioned_sigma,
        ).to(device)

        # init params
        if not self.conditioned_sigma:
            torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)

        return actor
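

# Hedged usage sketch: how ContinuousActorProbFactory might be used by experiment-building
# code. `envs` is assumed to be an already constructed tianshou.highlevel.env.Environments
# instance (not created here); the hidden sizes and the device string are arbitrary examples.
def _example_actor_factory_usage(envs: Environments) -> nn.Module:
    actor_factory = ContinuousActorProbFactory(hidden_sizes=[64, 64], conditioned_sigma=False)
    # Builds a tanh MLP backbone plus a Gaussian ActorProb head sized to the env's shapes.
    return actor_factory.create_module(envs, device="cpu")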


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        pass


class ContinuousCriticFactory(CriticFactory, ABC):
    pass


class ContinuousNetCriticFactory(ContinuousCriticFactory):
    def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
        self.action_shape = action_shape
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        action_shape = envs.get_action_shape() if use_action else 0
        net_c = Net(
            envs.get_state_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )
        critic = ContinuousCritic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic
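

# Hedged usage sketch: building both a Q-critic (state-action input) and a V-critic
# (state-only input) from the same ContinuousNetCriticFactory. As above, `envs` is assumed
# to be an existing Environments instance; the hidden sizes and device are placeholders.
def _example_critic_factory_usage(envs: Environments) -> tuple[nn.Module, nn.Module]:
    critic_factory = ContinuousNetCriticFactory(hidden_sizes=[64, 64])
    q_critic = critic_factory.create_module(envs, device="cpu", use_action=True)
    v_critic = critic_factory.create_module(envs, device="cpu", use_action=False)
    return q_critic, v_critic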