Dominik Jain 6b6d9ea609 Add support for discrete PPO
* Refactor module `module` (split into submodules)
* Basic support for discrete environments
* Implement Atari env. factory
* Implement DQN-based actor factory
* Add support for reusing the agent's preprocessing network for the critic
* Add example atari_ppo_hl
2023-10-18 20:44:16 +02:00


from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net

TDevice: TypeAlias = str | int | torch.device


def init_linear_orthogonal(module: torch.nn.Module) -> None:
    """Applies orthogonal initialization to the linear layers of the given module and sets bias weights to 0.

    :param module: the module whose submodules are to be processed
    """
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
@dataclass
class Module:
    module: torch.nn.Module
    output_dim: int


class ModuleFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> Module:
        pass
class ModuleFactoryNet(ModuleFactory):
    def __init__(self, hidden_sizes: int | Sequence[int]):
        # Accept a single int as shorthand for a single hidden layer
        self.hidden_sizes = [hidden_sizes] if isinstance(hidden_sizes, int) else hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> Module:
        # Build an MLP from the environments' observation shape, using the configured
        # hidden layer sizes, and move it to the requested device
        module = Net(
            envs.get_observation_shape(),
            hidden_sizes=self.hidden_sizes,
            device=device,
        ).to(device)
        return Module(module, module.output_dim)
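
A minimal usage sketch (not part of the file above): it assumes an `Environments` instance `envs` has already been constructed elsewhere, e.g. by the Atari environment factory mentioned in the commit message; the variable names and hidden-layer sizes below are illustrative only.

# Hypothetical usage: `envs` is an existing Environments instance created elsewhere
device = "cuda" if torch.cuda.is_available() else "cpu"
factory = ModuleFactoryNet(hidden_sizes=[64, 64])
module = factory.create_module(envs, device)
init_linear_orthogonal(module.module)  # orthogonal init for all linear layers, zero biases
latent_dim = module.output_dim  # dimensionality of the network's output, e.g. for an actor/critic head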