| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							|  |  |  | from collections.abc import Sequence | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  | from dataclasses import dataclass | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  | from enum import Enum | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  | from typing import Protocol | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | from torch import nn | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  | from tianshou.highlevel.env import Environments, EnvType | 
					
						
							| 
									
										
										
										
											2023-10-11 15:31:38 +02:00
										 |  |  | from tianshou.highlevel.module.core import ( | 
					
						
							|  |  |  |     ModuleFactory, | 
					
						
							|  |  |  |     TDevice, | 
					
						
							|  |  |  |     init_linear_orthogonal, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 |  |  | from tianshou.highlevel.module.intermediate import ( | 
					
						
							|  |  |  |     IntermediateModule, | 
					
						
							|  |  |  |     IntermediateModuleFactory, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  | from tianshou.highlevel.module.module_opt import ModuleOpt | 
					
						
							|  |  |  | from tianshou.highlevel.optim import OptimizerFactory | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  | from tianshou.utils.net import continuous, discrete | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  | from tianshou.utils.net.common import BaseActor, ModuleType, Net | 
					
						
							| 
									
										
										
										
											2023-11-07 10:54:22 +01:00
										 |  |  | from tianshou.utils.string import ToStringMixin | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  | class ContinuousActorType(Enum): | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |     GAUSSIAN = "gaussian" | 
					
						
							|  |  |  |     DETERMINISTIC = "deterministic" | 
					
						
							| 
									
										
										
										
											2023-10-05 19:21:08 +02:00
										 |  |  |     UNSUPPORTED = "unsupported" | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  | @dataclass | 
					
						
							|  |  |  | class ActorFuture: | 
					
						
							|  |  |  |     """Container, which, in the future, will hold an actor instance.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     actor: BaseActor | nn.Module | None = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ActorFutureProviderProtocol(Protocol): | 
					
						
							|  |  |  |     def get_actor_future(self) -> ActorFuture: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-11 15:31:38 +02:00
										 |  |  | class ActorFactory(ModuleFactory, ToStringMixin, ABC): | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2023-10-05 19:21:08 +02:00
										 |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |     def create_module_opt( | 
					
						
							| 
									
										
										
										
											2023-10-10 13:12:25 +02:00
										 |  |  |         self, | 
					
						
							|  |  |  |         envs: Environments, | 
					
						
							|  |  |  |         device: TDevice, | 
					
						
							|  |  |  |         optim_factory: OptimizerFactory, | 
					
						
							|  |  |  |         lr: float, | 
					
						
							| 
									
										
										
										
											2023-10-10 12:55:25 +02:00
										 |  |  |     ) -> ModuleOpt: | 
					
						
							|  |  |  |         """Creates the actor module along with its optimizer for the given learning rate.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param envs: the environments | 
					
						
							|  |  |  |         :param device: the torch device | 
					
						
							|  |  |  |         :param optim_factory: the optimizer factory | 
					
						
							|  |  |  |         :param lr: the learning rate | 
					
						
							|  |  |  |         :return: a container with the actor module and its optimizer | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         module = self.create_module(envs, device) | 
					
						
							|  |  |  |         optim = optim_factory.create_optimizer(module, lr) | 
					
						
							|  |  |  |         return ModuleOpt(module, optim) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |     def _init_linear(actor: torch.nn.Module) -> None: | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |         """Initializes linear layers of an actor module using default mechanisms.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param module: the actor module. | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         init_linear_orthogonal(actor) | 
					
						
							|  |  |  |         if hasattr(actor, "mu"): | 
					
						
							|  |  |  |             # For continuous action spaces with Gaussian policies | 
					
						
							|  |  |  |             # do last policy layer scaling, this will make initial actions have (close to) | 
					
						
							|  |  |  |             # 0 mean and std, and will help boost performances, | 
					
						
							|  |  |  |             # see https://arxiv.org/abs/2006.05990, Fig.24 for details | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |             for m in actor.mu.modules():  # type: ignore | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |                 if isinstance(m, torch.nn.Linear): | 
					
						
							|  |  |  |                     m.weight.data.copy_(0.01 * m.weight.data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 17:20:35 +02:00
										 |  |  | class ActorFactoryDefault(ActorFactory): | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |     """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |     DEFAULT_HIDDEN_SIZES = (64, 64) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |         continuous_actor_type: ContinuousActorType, | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |         hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         hidden_activation: ModuleType = nn.ReLU, | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |         continuous_unbounded: bool = False, | 
					
						
							|  |  |  |         continuous_conditioned_sigma: bool = False, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |         discrete_softmax: bool = True, | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |         self.continuous_actor_type = continuous_actor_type | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |         self.continuous_unbounded = continuous_unbounded | 
					
						
							|  |  |  |         self.continuous_conditioned_sigma = continuous_conditioned_sigma | 
					
						
							|  |  |  |         self.hidden_sizes = hidden_sizes | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         self.hidden_activation = hidden_activation | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |         self.discrete_softmax = discrete_softmax | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor: | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |         env_type = envs.get_type() | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |         factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |         if env_type == EnvType.CONTINUOUS: | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |             match self.continuous_actor_type: | 
					
						
							|  |  |  |                 case ContinuousActorType.GAUSSIAN: | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |                     factory = ActorFactoryContinuousGaussianNet( | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |                         self.hidden_sizes, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |                         activation=self.hidden_activation, | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |                         unbounded=self.continuous_unbounded, | 
					
						
							|  |  |  |                         conditioned_sigma=self.continuous_conditioned_sigma, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 case ContinuousActorType.DETERMINISTIC: | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |                     factory = ActorFactoryContinuousDeterministicNet( | 
					
						
							|  |  |  |                         self.hidden_sizes, | 
					
						
							|  |  |  |                         activation=self.hidden_activation, | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2023-10-05 19:21:08 +02:00
										 |  |  |                 case ContinuousActorType.UNSUPPORTED: | 
					
						
							|  |  |  |                     raise ValueError("Continuous action spaces are not supported by the algorithm") | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |                 case _: | 
					
						
							|  |  |  |                     raise ValueError(self.continuous_actor_type) | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |             return factory.create_module(envs, device) | 
					
						
							|  |  |  |         elif env_type == EnvType.DISCRETE: | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |             factory = ActorFactoryDiscreteNet( | 
					
						
							| 
									
										
										
										
											2023-10-11 19:31:26 +02:00
										 |  |  |                 self.DEFAULT_HIDDEN_SIZES, | 
					
						
							|  |  |  |                 softmax_output=self.discrete_softmax, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-10-05 19:21:08 +02:00
										 |  |  |             return factory.create_module(envs, device) | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise ValueError(f"{env_type} not supported") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 17:20:35 +02:00
										 |  |  | class ActorFactoryContinuous(ActorFactory, ABC): | 
					
						
							| 
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 |  |  |     """Serves as a type bound for actor factories that are suitable for continuous action spaces.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  | class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |     def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |         self.hidden_sizes = hidden_sizes | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         self.activation = activation | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor: | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |         net_a = Net( | 
					
						
							|  |  |  |             envs.get_observation_shape(), | 
					
						
							|  |  |  |             hidden_sizes=self.hidden_sizes, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |             activation=self.activation, | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |             device=device, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         return continuous.Actor( | 
					
						
							|  |  |  |             net_a, | 
					
						
							|  |  |  |             envs.get_action_shape(), | 
					
						
							|  |  |  |             hidden_sizes=(), | 
					
						
							|  |  |  |             device=device, | 
					
						
							|  |  |  |         ).to(device) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  | class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         hidden_sizes: Sequence[int], | 
					
						
							|  |  |  |         unbounded: bool = True, | 
					
						
							|  |  |  |         conditioned_sigma: bool = False, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         activation: ModuleType = nn.ReLU, | 
					
						
							| 
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 |  |  |     ): | 
					
						
							| 
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 |  |  |         """:param hidden_sizes: the sequence of hidden dimensions to use in the network structure
 | 
					
						
							|  |  |  |         :param unbounded: whether to apply tanh activation on final logits | 
					
						
							|  |  |  |         :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the | 
					
						
							|  |  |  |             input; if False, sigma is an independent parameter | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         self.hidden_sizes = hidden_sizes | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |         self.unbounded = unbounded | 
					
						
							|  |  |  |         self.conditioned_sigma = conditioned_sigma | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         self.activation = activation | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor: | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         net_a = Net( | 
					
						
							| 
									
										
										
										
											2023-09-25 17:56:37 +02:00
										 |  |  |             envs.get_observation_shape(), | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             hidden_sizes=self.hidden_sizes, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |             activation=self.activation, | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             device=device, | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 |  |  |         actor = continuous.ActorProb( | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             net_a, | 
					
						
							|  |  |  |             envs.get_action_shape(), | 
					
						
							| 
									
										
										
										
											2023-09-20 13:15:06 +02:00
										 |  |  |             unbounded=self.unbounded, | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             device=device, | 
					
						
							|  |  |  |             conditioned_sigma=self.conditioned_sigma, | 
					
						
							|  |  |  |         ).to(device) | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # init params | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |         if not self.conditioned_sigma: | 
					
						
							|  |  |  |             torch.nn.init.constant_(actor.sigma_param, -0.5) | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         self._init_linear(actor) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return actor | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  | class ActorFactoryDiscreteNet(ActorFactory): | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         hidden_sizes: Sequence[int], | 
					
						
							|  |  |  |         softmax_output: bool = True, | 
					
						
							|  |  |  |         activation: ModuleType = nn.ReLU, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         self.hidden_sizes = hidden_sizes | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |         self.softmax_output = softmax_output | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |         self.activation = activation | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor: | 
					
						
							|  |  |  |         net_a = Net( | 
					
						
							| 
									
										
										
										
											2023-09-25 17:56:37 +02:00
										 |  |  |             envs.get_observation_shape(), | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             hidden_sizes=self.hidden_sizes, | 
					
						
							| 
									
										
										
										
											2023-10-18 13:57:36 +02:00
										 |  |  |             activation=self.activation, | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  |             device=device, | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |         return discrete.Actor( | 
					
						
							|  |  |  |             net_a, | 
					
						
							|  |  |  |             envs.get_action_shape(), | 
					
						
							|  |  |  |             hidden_sizes=(), | 
					
						
							|  |  |  |             device=device, | 
					
						
							| 
									
										
										
										
											2023-10-11 16:07:34 +02:00
										 |  |  |             softmax_output=self.softmax_output, | 
					
						
							| 
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 |  |  |         ).to(device) | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ActorFactoryTransientStorageDecorator(ActorFactory): | 
					
						
							| 
									
										
										
										
											2023-11-06 17:17:43 +01:00
										 |  |  |     """Wraps an actor factory, storing the most recently created actor instance such that it
 | 
					
						
							|  |  |  |     can be retrieved. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  |     def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture): | 
					
						
							|  |  |  |         self.actor_factory = actor_factory | 
					
						
							|  |  |  |         self._actor_future = actor_future | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-13 12:25:28 +02:00
										 |  |  |     def __getstate__(self) -> dict: | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  |         d = dict(self.__dict__) | 
					
						
							|  |  |  |         del d["_actor_future"] | 
					
						
							|  |  |  |         return d | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-13 12:25:28 +02:00
										 |  |  |     def __setstate__(self, state: dict) -> None: | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  |         self.__dict__ = state | 
					
						
							|  |  |  |         self._actor_future = ActorFuture() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-13 12:25:28 +02:00
										 |  |  |     def _tostring_excludes(self) -> list[str]: | 
					
						
							| 
									
										
										
										
											2023-10-10 19:11:49 +02:00
										 |  |  |         return [*super()._tostring_excludes(), "_actor_future"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: | 
					
						
							|  |  |  |         module = self.actor_factory.create_module(envs, device) | 
					
						
							|  |  |  |         self._actor_future.actor = module | 
					
						
							|  |  |  |         return module | 
					
						
							| 
									
										
										
										
											2023-10-11 15:31:38 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory): | 
					
						
							|  |  |  |     def __init__(self, actor_factory: ActorFactory): | 
					
						
							|  |  |  |         self.actor_factory = actor_factory | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: | 
					
						
							|  |  |  |         actor = self.actor_factory.create_module(envs, device) | 
					
						
							|  |  |  |         assert isinstance(actor, BaseActor) | 
					
						
							|  |  |  |         return IntermediateModule(actor, actor.get_output_dim()) |