Dominik Jain dae4000cd2 Revert "Depend on sensAI instead of copying its utils (logging, string)"
This reverts commit fdb0eba93d81fa5e698770b4f7088c87fc1238da.
2023-11-08 19:11:39 +01:00

37 lines
1.1 KiB
Python

from abc import ABC, abstractmethod
import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.string import ToStringMixin
class AutoAlphaFactory(ToStringMixin, ABC):
    """Abstract factory for the objects needed to tune an entropy coefficient automatically.

    Concrete subclasses implement :meth:`create_auto_alpha`.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the auto-alpha triple.

        :param envs: the environments for which the triple is created
        :param optim_factory: the optimizer factory made available to the implementation
        :param device: the torch device made available to the implementation
        :return: a tuple ``(target_entropy, log_alpha, alpha_optim)`` consisting of
            a float, a tensor and a torch optimizer
        """
class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
    """Default auto-alpha factory.

    Derives the target entropy from the action shape of the environments and
    optimizes a single-element log-alpha tensor with Adam.
    """

    def __init__(self, lr: float = 3e-4):
        """
        :param lr: the learning rate handed to the Adam optimizer that updates log-alpha
        """
        self.lr = lr

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the target entropy, the log-alpha tensor and its optimizer.

        NOTE(review): ``optim_factory`` is accepted for interface compatibility but is
        not used here; the optimizer is always ``torch.optim.Adam`` - confirm whether
        this is intended.

        :param envs: the environments; only their action shape is queried
        :param optim_factory: unused by this implementation
        :param device: the device on which the log-alpha tensor is allocated
        :return: a tuple ``(target_entropy, log_alpha, alpha_optim)`` where
            ``target_entropy`` is the negated product of the action shape's entries,
            ``log_alpha`` is a zero-initialized tensor with one element that requires
            gradients, and ``alpha_optim`` is an Adam optimizer over ``log_alpha``
        """
        # Product over all entries of the action shape.
        action_size = np.prod(envs.get_action_shape())
        entropy_target = float(-action_size)

        log_alpha_param = torch.zeros(1, requires_grad=True, device=device)
        optimizer = torch.optim.Adam([log_alpha_param], lr=self.lr)
        return entropy_target, log_alpha_param, optimizer