Dominik Jain 367778d37f Improve high-level policy parametrisation
Policy objects are now parametrised by converting the parameter
dataclass instances to kwargs, using some injectable conversions
along the way
2023-10-18 20:44:16 +02:00

36 lines
1.0 KiB
Python

from abc import ABC, abstractmethod
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory
class AutoAlphaFactory(ABC):
    """Abstract factory for the objects needed to tune alpha automatically.

    Concrete subclasses decide how the target entropy is derived from the
    environments and how the learnable log-alpha tensor and its optimizer
    are constructed.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Build the automatic-alpha components.

        :param envs: the environments whose properties may determine the target entropy
        :param optim_factory: a factory which implementations may use to create the optimizer
        :param device: the torch device on which tensors are to be created
        :return: a triple ``(target_entropy, log_alpha, optimizer)``
        """
class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
    """Default auto-alpha factory.

    The target entropy is the negated product of the action shape's
    dimensions, and a one-element log-alpha tensor is optimized with Adam.
    """

    def __init__(self, lr: float = 3e-4):
        """:param lr: learning rate of the Adam optimizer created for ``log_alpha``"""
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the target entropy, the learnable log-alpha tensor and its optimizer.

        :param envs: the environments; only their action shape is used
        :param optim_factory: currently unused (see note in the body)
        :param device: the torch device on which ``log_alpha`` is allocated
        :return: a triple ``(target_entropy, log_alpha, alpha_optim)`` where
            ``target_entropy`` is a Python float
        """
        # np.prod returns a NumPy scalar (an integer type for integer shapes);
        # convert so the value actually matches the declared ``float`` return type.
        target_entropy = float(-np.prod(envs.get_action_shape()))
        # One-element tensor, initialised to zero and tracked for gradients.
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        # NOTE(review): ``optim_factory`` is ignored and Adam is hard-coded here;
        # consider creating the optimizer via the factory once its API supports
        # raw parameter lists.
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim