Dominik Jain 78b6dd1f49 Adapt class naming scheme
* Use prefix convention (subclasses have superclass names as prefix) to
  facilitate discoverability of relevant classes via IDE autocompletion
* Use dual naming, adding an alternative concise name that omits the
  precise OO semantics and retains only the essential part of the name
  (which can be more pleasing to users not accustomed to
  convoluted OO naming)
2023-10-18 20:44:16 +02:00

36 lines
1.0 KiB
Python

from abc import ABC, abstractmethod
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module import TDevice
from tianshou.highlevel.optim import OptimizerFactory
class AutoAlphaFactory(ABC):
    """Abstract factory for the components needed to tune alpha automatically.

    Judging by the name, alpha is presumably the entropy coefficient of a
    SAC-style policy (TODO confirm against callers). Implementations return a
    target entropy, a trainable log-alpha tensor and an optimizer for that
    tensor.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Build the auto-alpha components.

        :param envs: the environments the policy will act in.
        :param optim_factory: a factory that implementations may use to create
            the optimizer.
        :param device: the device on which tensors should be allocated.
        :return: a triple ``(target_entropy, log_alpha, alpha_optimizer)``.
        """
class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
    """Default auto-alpha factory.

    Uses a heuristic target entropy equal to the negated product of the
    environments' action shape. Log-alpha is initialised to zero, and an
    Adam optimizer is created for it.
    """

    def __init__(self, lr: float = 3e-4):
        """
        :param lr: learning rate of the Adam optimizer created for log-alpha.
        """
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the target entropy, the trainable log-alpha tensor and its optimizer.

        :param envs: environments whose ``get_action_shape()`` determines the
            target entropy.
        :param optim_factory: currently unused; the optimizer is always
            ``torch.optim.Adam`` configured with ``self.lr``.
        :param device: device on which the log-alpha tensor is allocated.
        :return: ``(target_entropy, log_alpha, alpha_optim)``, where:
            ``target_entropy`` is a plain Python ``float``;
            ``log_alpha`` is a one-element zero tensor with ``requires_grad=True``;
            ``alpha_optim`` is an Adam optimizer over ``[log_alpha]``.
        """
        # np.prod returns a NumPy scalar (e.g. np.int64). Convert it so the value
        # matches the declared ``float`` return type and remains serialisable
        # (json.dumps rejects NumPy integer scalars).
        target_entropy = -float(np.prod(envs.get_action_shape()))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        # NOTE(review): optim_factory is ignored and Adam is hard-coded;
        # consider routing optimizer creation through optim_factory.
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim