Tianshou/tianshou/highlevel/params/policy_wrapper.py

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from sensai.util.string import ToStringMixin

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Factory for the creation of a wrapped policy, i.e. a policy that extends
    the behaviour of a policy created by another factory."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        pass
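
For illustration only (this is not part of the module): the abstract interface above is all a custom wrapper factory needs to satisfy. A minimal sketch, assuming only imports already used in this file, might look as follows; the class name PolicyWrapperFactoryIdentity is hypothetical, and the factory simply returns the policy unchanged.

# Hypothetical sketch, not part of policy_wrapper.py.
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.policy import BasePolicy


class PolicyWrapperFactoryIdentity(PolicyWrapperFactory[BasePolicy]):
    """A no-op wrapper factory: returns the given policy unchanged."""

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> BasePolicy:
        return policy

The concrete implementation the module actually provides wraps a policy with intrinsic curiosity (ICM):
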
class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        """
        :param feature_net_factory: the factory for the creation of the feature network used by the ICM
        :param hidden_sizes: the hidden layer sizes of the ICM's MLP
        :param lr: the learning rate for the ICM's optimizer
        :param lr_scale: the scaling factor for ICM learning (passed on to ICMPolicy)
        :param reward_scale: the scaling factor for the intrinsic rewards (passed on to ICMPolicy)
        :param forward_loss_weight: the weight of the forward model loss (passed on to ICMPolicy)
        """
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        # Create the feature network; its output dimension determines the feature
        # dimension of the intrinsic curiosity module.
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        # Wrap the given policy such that intrinsic rewards computed by the ICM
        # are added to the environment rewards during training.
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)
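
As a usage sketch (not part of the module): in the high-level API, such a factory is registered with an experiment builder rather than called directly. The snippet below only constructs the ICM wrapper factory; MyFeatureNetFactory stands for a hypothetical IntermediateModuleFactory implementation, and the numeric values are placeholders.

# Hypothetical sketch; MyFeatureNetFactory and all parameter values are placeholders.
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactoryIntrinsicCuriosity

icm_wrapper_factory = PolicyWrapperFactoryIntrinsicCuriosity(
    feature_net_factory=MyFeatureNetFactory(),  # hypothetical feature net factory
    hidden_sizes=[256, 128],
    lr=1e-3,
    lr_scale=1.0,
    reward_scale=0.01,
    forward_loss_weight=0.2,
)
# The factory would then be handed to the high-level experiment builder (e.g. via
# ExperimentBuilder.with_policy_wrapper_factory, if available in your version), which
# calls create_wrapped_policy when the experiment's policy is assembled.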