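"""Factories for wrapping policies created via the high-level API, e.g. with an
intrinsic curiosity module (ICM) for exploration.

Usage sketch (illustrative, not from this module): a factory instance is passed
to an experiment builder from ``tianshou.highlevel.experiment`` via its
``with_policy_wrapper_factory`` method; ``MyFeatureNetFactory`` and the
hyperparameter values below are hypothetical placeholders::

    builder.with_policy_wrapper_factory(
        PolicyWrapperFactoryIntrinsicCuriosity(
            feature_net_factory=MyFeatureNetFactory(),
            hidden_sizes=[512],
            lr=1e-3,
            lr_scale=1.0,
            reward_scale=0.01,
            forward_loss_weight=0.2,
        )
    )
"""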
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from sensai.util.string import ToStringMixin

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Factory for wrapping a policy, e.g. to augment it with auxiliary modules."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        """Creates a wrapped policy from the given policy.

        :param policy: the policy to wrap
        :param envs: the environments
        :param optim_factory: the factory for the optimizer of any additional modules
        :param device: the torch device to use
        :return: the wrapped policy
        """


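# A minimal sketch of the contract defined above (hypothetical example, not a
# library class): the simplest possible concrete factory wraps the policy by
# returning it unchanged, ignoring the creation context.
class PolicyWrapperFactoryIdentity(PolicyWrapperFactory[BasePolicy]):
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> BasePolicy:
        # No auxiliary modules are created; the base policy is returned as-is.
        return policy

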
class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    """Factory for wrapping a policy with an intrinsic curiosity module (ICM)."""

    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        """
        :param feature_net_factory: the factory for the feature network used by the ICM
        :param hidden_sizes: the hidden layer sizes of the ICM network
        :param lr: the learning rate for the ICM's optimizer
        :param lr_scale: passed on to :class:`ICMPolicy`
        :param reward_scale: passed on to :class:`ICMPolicy`
        :param forward_loss_weight: passed on to :class:`ICMPolicy`
        """
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        # Create the feature network from which the ICM derives its state embeddings.
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        # The ICM requires a discrete action space, i.e. an integer action dimension.
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        # The ICM is trained by its own optimizer, created with the configured learning rate.
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)