from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.utils.string import ToStringMixin


@dataclass
class IntermediateModule:
    """Container for a module which computes an intermediate representation (with a known dimension)."""

    # the module computing the intermediate representation
    module: torch.nn.Module
    # the dimension of the representation computed by `module`
    output_dim: int


class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
    """Factory for the generation of a module which computes an intermediate representation."""

    @abstractmethod
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Create the intermediate module together with its output dimension.

        :param envs: the environments for which the module is created
        :param device: the device passed on to module creation (presumably where the
            module's parameters are to be placed -- confirm in implementations)
        :return: the module wrapped together with the dimension of its output
        """
        pass

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Create the module, discarding the output-dimension information.

        This satisfies the generic module-factory interface by delegating to
        :meth:`create_intermediate_module` and returning only its ``module``.

        :param envs: the environments for which the module is created
        :param device: the device passed through to :meth:`create_intermediate_module`
        :return: the module computing the intermediate representation
        """
        return self.create_intermediate_module(envs, device).module