Policy objects are now parametrised by converting parameter dataclass instances to kwargs, applying injectable conversions along the way
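Roughly, the mechanism looks like the following minimal sketch (illustrative only; the names ParamTransformer, SACParams and params_to_kwargs are hypothetical, not the actual conversion code introduced by this change): a parameter dataclass is flattened into a kwargs dict, and injectable transformer objects rewrite individual entries before the dict is passed to the policy constructor.

from dataclasses import asdict, dataclass
from typing import Any, Protocol


class ParamTransformer(Protocol):
    # Injectable conversion: mutates the kwargs dict in place,
    # e.g. replacing a factory entry by the object it creates.
    def transform(self, kwargs: dict[str, Any]) -> None: ...


@dataclass
class SACParams:
    gamma: float = 0.99
    tau: float = 0.005


def params_to_kwargs(params: SACParams, transformers: list[ParamTransformer]) -> dict[str, Any]:
    kwargs = asdict(params)
    for t in transformers:
        t.transform(kwargs)
    return kwargs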
from abc import ABC, abstractmethod

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory


class AutoAlphaFactory(ABC):
    """Factory for the automatic entropy-coefficient (alpha) tuning mechanism of SAC."""

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        pass


class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,  # available for custom implementations; unused here
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        # Standard SAC heuristic: target entropy = -dim(action space).
        # Cast to float so the value matches the declared return type.
        target_entropy = -float(np.prod(envs.get_action_shape()))
        # Optimize log(alpha) rather than alpha directly, keeping alpha positive.
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim
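For context, a sketch of how such a factory might be consumed (hypothetical wiring; envs, optim_factory and device are assumed to be provided by the surrounding high-level setup):

# Hypothetical usage, not part of this change.
factory = DefaultAutoAlphaFactory(lr=1e-4)
alpha = factory.create_auto_alpha(envs, optim_factory, device)
# SACPolicy accepts alpha either as a fixed float or as this
# (target_entropy, log_alpha, alpha_optim) tuple for automatic tuning.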