Dominik Jain 6b6d9ea609 Add support for discrete PPO
* Refactor module `module` (split into submodules)
* Add basic support for discrete environments
* Implement Atari env. factory
* Implement DQN-based actor factory
* Implement notion of reusing agent preprocessing network for critic
* Add example atari_ppo_hl
2023-10-18 20:44:16 +02:00

from abc import ABC, abstractmethod

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory


class AutoAlphaFactory(ABC):
    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        pass


class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        # Standard SAC heuristic: target entropy is -dim(action space).
        target_entropy = float(-np.prod(envs.get_action_shape()))
        # Optimize log(alpha) rather than alpha itself so alpha stays positive.
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        # NOTE: optim_factory is not used here yet; Adam is hardcoded.
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim
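
For reference, a minimal sketch of how a (target_entropy, log_alpha, alpha_optim) triple like the one returned above is typically consumed during SAC-style entropy tuning. The loss is the standard SAC entropy-coefficient objective; `update_alpha` and `log_probs` are hypothetical names introduced here for illustration and are not part of this module.

import torch


def update_alpha(
    log_alpha: torch.Tensor,
    alpha_optim: torch.optim.Optimizer,
    target_entropy: float,
    log_probs: torch.Tensor,
) -> float:
    """One gradient step on the standard SAC alpha objective.

    Hypothetical helper for illustration; `log_probs` holds the policy's
    log-probabilities for a batch of sampled actions.
    """
    # Alpha grows when policy entropy (-log_probs) falls below the
    # target entropy and shrinks when it exceeds it.
    alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    return log_alpha.exp().item()  # current value of alpha


# Example: tuning alpha for a 6-dimensional continuous action space,
# with stand-in log-probabilities for a batch of 256 actions.
target_entropy = -6.0
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = update_alpha(log_alpha, alpha_optim, target_entropy,
                     log_probs=torch.randn(256))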