# Tianshou/tianshou/highlevel/params/policy_wrapper.py

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin

TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Factory for the creation of a policy that wraps a given base policy."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        """Creates a policy that wraps the given policy.

        :param policy: the policy to be wrapped
        :param envs: the environments
        :param optim_factory: the factory for any optimizers the wrapper requires
        :param device: the torch device to use
        :return: the wrapped policy
        """
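
# A minimal sketch (illustrative, not part of this module) of how a custom
# wrapper factory could be implemented against this interface; MyWrapperPolicy
# is a hypothetical BasePolicy subclass that wraps another policy:
#
#     class PolicyWrapperFactoryMyWrapper(PolicyWrapperFactory[MyWrapperPolicy]):
#         def create_wrapped_policy(
#             self,
#             policy: BasePolicy,
#             envs: Environments,
#             optim_factory: OptimizerFactory,
#             device: TDevice,
#         ) -> MyWrapperPolicy:
#             return MyWrapperPolicy(policy=policy, action_space=envs.get_action_space())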


class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    """Factory for wrapping a policy with an intrinsic curiosity module (ICM)."""

    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        """
        :param feature_net_factory: the factory for the feature network used by the ICM
        :param hidden_sizes: the hidden layer sizes of the ICM's networks
        :param lr: the learning rate for the ICM's optimizer
        :param lr_scale: the scaling factor for ICM learning
        :param reward_scale: the scaling factor for the intrinsic reward
        :param forward_loss_weight: the weight of the forward model loss
        """
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        # Build the feature network that maps observations to the ICM's feature space.
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
        # The ICM combines a forward model (predicting the next features from
        # features and action) with an inverse model (predicting the action
        # from consecutive features).
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)
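
# Usage sketch (illustrative): in the high-level API, an experiment builder can
# attach this factory so that the resulting policy is wrapped with ICM. The
# feature net factory and the hyperparameter values below are assumptions for
# illustration, not taken from this file:
#
#     builder.with_policy_wrapper_factory(
#         PolicyWrapperFactoryIntrinsicCuriosity(
#             feature_net_factory=my_feature_net_factory,  # some IntermediateModuleFactory
#             hidden_sizes=[512],
#             lr=1e-3,
#             lr_scale=1.0,
#             reward_scale=0.01,
#             forward_loss_weight=0.2,
#         ),
#     )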