2023-09-25 17:56:37 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from tianshou.highlevel.env import Environments
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.highlevel.module.core import TDevice
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.highlevel.optim import OptimizerFactory
|
2023-11-07 10:54:22 +01:00
|
|
|
from tianshou.utils.string import ToStringMixin
|
2023-09-25 17:56:37 +02:00
|
|
|
|
|
|
|
|
2023-10-05 13:15:24 +02:00
|
|
|
class AutoAlphaFactory(ToStringMixin, ABC):
    """Abstract factory for the objects needed to tune the ``alpha`` coefficient automatically.

    Subclasses implement :meth:`create_auto_alpha`, which builds a target value,
    a trainable ``alpha``-related tensor and the optimizer that updates it.
    """

    @abstractmethod
    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the components for automatic ``alpha`` tuning.

        :param envs: the environments; implementations may derive the target value
            from them (e.g. from the action shape).
        :param optim_factory: an optimizer factory that implementations may use to
            build the returned optimizer.
        :param device: the device on which the returned tensor should be allocated.
        :return: a 3-tuple of (target value as a float, trainable tensor, optimizer
            for that tensor).
        """
|
|
|
|
|
|
|
|
|
2023-09-27 17:20:35 +02:00
|
|
|
class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
    """Default auto-alpha factory: a zero-initialised ``log_alpha`` updated by Adam.

    The target entropy is ``target_entropy_coefficient * prod(action_shape)``. With
    the default coefficient of ``-1.0`` this equals the negated product of the
    action shape, which matches the previously hard-coded behaviour.
    """

    def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0):
        """
        :param lr: learning rate of the Adam optimizer that updates ``log_alpha``.
        :param target_entropy_coefficient: factor applied to the product of the
            action shape to obtain the target entropy. Defaults to ``-1.0``.
        """
        self.lr = lr
        self.target_entropy_coefficient = target_entropy_coefficient

    def create_auto_alpha(
        self,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        """Create the target entropy, the trainable ``log_alpha`` tensor and its optimizer.

        :param envs: the environments; only ``get_action_shape()`` is consulted.
        :param optim_factory: currently unused (see note in the body).
        :param device: device on which ``log_alpha`` is allocated.
        :return: ``(target_entropy, log_alpha, alpha_optim)``, where ``log_alpha``
            is a 1-element zero tensor with ``requires_grad=True`` and
            ``alpha_optim`` is an Adam optimizer over it with learning rate ``self.lr``.
        """
        # NOTE(review): for discrete action spaces get_action_shape() may return the
        # number of actions rather than a shape; confirm the resulting target entropy
        # is what's intended in that case.
        action_dim = float(np.prod(envs.get_action_shape()))
        target_entropy = self.target_entropy_coefficient * action_dim
        # log-parameterisation starting at 0, i.e. an initial alpha of exp(0) = 1.
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        # NOTE(review): `optim_factory` is ignored and Adam is hard-coded. Consider
        # routing this through the factory once it supports plain parameter lists.
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim
|