36 lines
1.0 KiB
Python
36 lines
1.0 KiB
Python
|
from abc import ABC, abstractmethod
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from tianshou.highlevel.env import Environments
|
||
|
from tianshou.highlevel.module import TDevice
|
||
|
from tianshou.highlevel.optim import OptimizerFactory
|
||
|
|
||
|
|
||
|
class AutoAlphaFactory(ABC):
    """Interface for factories that create the components required for automatic
    tuning of the entropy coefficient alpha (as used, e.g., by SAC-style
    algorithms with entropy regularization).
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Creates the components for automatic alpha tuning.

        :param envs: the environments (e.g. providing the action shape)
        :param optim_factory: factory with which an optimizer may be created
        :param device: the torch device on which to allocate tensors
        :return: a tuple (target_entropy, log_alpha, alpha_optimizer)
        """
        pass
|
||
|
|
||
|
|
||
|
class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
|
||
|
def __init__(self, lr: float = 3e-4):
|
||
|
self.lr = lr
|
||
|
|
||
|
def create_auto_alpha(
|
||
|
self,
|
||
|
envs: Environments,
|
||
|
optim_factory: OptimizerFactory,
|
||
|
device: TDevice,
|
||
|
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
|
||
|
target_entropy = -np.prod(envs.get_action_shape())
|
||
|
log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||
|
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
|
||
|
return target_entropy, log_alpha, alpha_optim
|