# Tianshou/tianshou/highlevel/module/intermediate.py
#
# 28 lines
# 895 B
# Python
# Raw Normal View History

from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.utils.string import ToStringMixin
@dataclass
class IntermediateModule:
    """Container for a module which computes an intermediate representation (with a known dimension).

    Bundles the torch module with an integer dimension so that both can be
    passed around as a single value.
    """

    # The torch module computing the intermediate representation.
    module: torch.nn.Module
    # The known dimension associated with the module; presumably the size of
    # the representation that `module` outputs -- confirm against the callers.
    output_dim: int
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
    """Factory for the generation of a module which computes an intermediate representation."""

    @abstractmethod
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Create the intermediate module along with its known output dimension.

        Concrete factories must override this method.

        :param envs: the environments for which the module is created
        :param device: the device specification; presumably the torch device
            the module is to be placed on -- confirm against `ModuleFactory`
        :return: a container holding the created module and its output dimension
        """

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Create the intermediate module and return only the torch module.

        Delegates to :meth:`create_intermediate_module` and discards the
        output dimension, returning just the `module` attribute of the result.

        :param envs: the environments, forwarded unchanged
        :param device: the device specification, forwarded unchanged
        :return: the `module` attribute of the created intermediate module
        """
        return self.create_intermediate_module(envs, device).module