28 lines
		
	
	
		
			895 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			28 lines
		
	
	
		
			895 B
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | from abc import ABC, abstractmethod | ||
|  | from dataclasses import dataclass | ||
|  | 
 | ||
|  | import torch | ||
|  | 
 | ||
|  | from tianshou.highlevel.env import Environments | ||
|  | from tianshou.highlevel.module.core import ModuleFactory, TDevice | ||
|  | from tianshou.utils.string import ToStringMixin | ||
|  | 
 | ||
|  | 
 | ||
@dataclass()
class IntermediateModule:
    """Bundle of a torch module that computes an intermediate representation and the (known) dimension of that representation."""

    # The module that maps inputs to the intermediate representation.
    module: torch.nn.Module
    # Dimension of the representation produced by ``module``.
    output_dim: int
|  | 
 | ||
|  | 
 | ||
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
    """Abstract factory for modules that compute an intermediate representation.

    Subclasses implement :meth:`create_intermediate_module`, which returns the module
    together with its output dimension. :meth:`create_module` delegates to it and
    returns only the module, dropping the dimension.
    """

    @abstractmethod
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Build the intermediate-representation module.

        :param envs: the environments the module is created for
        :param device: the device specification passed through by the caller
            (presumably where the module should live -- confirm in subclasses)
        :return: the module paired with the dimension of its output
        """
        ...

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Build the module via :meth:`create_intermediate_module` and return only the module.

        :param envs: forwarded unchanged to :meth:`create_intermediate_module`
        :param device: forwarded unchanged to :meth:`create_intermediate_module`
        :return: the ``module`` part of the created :class:`IntermediateModule`
        """
        intermediate = self.create_intermediate_module(envs, device)
        # The output dimension is not needed by callers of this generic interface.
        return intermediate.module