# Naming conventions for the factory classes in this module:
# * Use the prefix convention (subclasses carry the superclass name as a prefix) to
#   facilitate discoverability of related classes via IDE autocompletion.
# * Use dual naming, adding an alternative concise alias that omits the precise OO
#   semantics and retains only the essential part of the name (which can be more
#   pleasing to users not accustomed to convoluted OO naming).
from abc import ABC, abstractmethod

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory


class AutoAlphaFactory(ABC):
    """Factory interface for the automatic adjustment of the entropy coefficient alpha.

    Implementations produce the three objects required to tune alpha during training:
    the entropy target, the learnable log-alpha tensor, and the optimizer that updates it.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the components for automatic alpha tuning.

        :param envs: the environments (used, e.g., to derive the action shape)
        :param optim_factory: factory for optimizer creation
        :param device: the torch device on which to allocate tensors
        :return: a tuple ``(target_entropy, log_alpha, alpha_optim)``
        """
|
class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
    """Default implementation: target entropy is the negated product of the action shape,
    log-alpha starts at zero, and is optimized with Adam.
    """

    def __init__(self, lr: float = 3e-4):
        """:param lr: the learning rate for the Adam optimizer of log-alpha"""
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the entropy target, the log-alpha tensor, and its Adam optimizer.

        :param envs: the environments; only the action shape is used here
        :param optim_factory: unused by this implementation (kept for interface compatibility)
        :param device: the torch device on which to allocate ``log_alpha``
        :return: a tuple ``(target_entropy, log_alpha, alpha_optim)``
        """
        # Cast to float: np.prod returns a numpy scalar, but the declared return
        # type (and downstream consumers expecting a plain float) require float.
        target_entropy = float(-np.prod(envs.get_action_shape()))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim