Tianshou/tianshou/highlevel/params/policy_wrapper.py
Dominik Jain a8a367c42d Support IQN in high-level API
* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improved module factory interface design, achieving higher generality
2023-10-18 20:44:17 +02:00


from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin

TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
    """Factory for creating a wrapper around a given policy (e.g. to add intrinsic rewards)."""

@abstractmethod
def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> TPolicyOut:
pass
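
# A minimal illustration of the contract above (hypothetical, not part of
# Tianshou): subclasses receive the inner policy together with the
# environments, the optimizer factory and the device, and return the wrapped
# policy. This no-op factory simply returns the policy unchanged.
class PolicyWrapperFactoryNoop(Generic[TPolicyIn], PolicyWrapperFactory[TPolicyIn, TPolicyIn]):
    def create_wrapped_policy(
        self,
        policy: TPolicyIn,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyIn:
        return policy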


class PolicyWrapperFactoryIntrinsicCuriosity(
    Generic[TPolicyIn],
    PolicyWrapperFactory[TPolicyIn, ICMPolicy],
):
    """Wraps the given policy in an :class:`ICMPolicy`, adding intrinsic curiosity rewards."""

def __init__(
self,
feature_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int],
lr: float,
lr_scale: float,
reward_scale: float,
forward_loss_weight: float,
    ):
        """:param feature_net_factory: factory for the feature network used by the ICM module
        :param hidden_sizes: hidden layer sizes of the ICM module's networks
        :param lr: learning rate for the ICM optimizer
        :param lr_scale: scaling factor for the ICM loss
        :param reward_scale: scaling factor for the intrinsic reward
        :param forward_loss_weight: weight of the forward model loss relative to the inverse model loss
        """
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
self,
policy: TPolicyIn,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
) -> ICMPolicy:
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        # The intrinsic curiosity module requires a discrete action space,
        # i.e. an integer action dimension
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.module,
feature_dim,
action_dim,
hidden_sizes=self.hidden_sizes,
device=device,
)
icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
return ICMPolicy(
policy=policy,
model=icm_net,
optim=icm_optim,
action_space=envs.get_action_space(),
lr_scale=self.lr_scale,
reward_scale=self.reward_scale,
forward_loss_weight=self.forward_loss_weight,
).to(device)
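

# Usage sketch (hedged): in the high-level API, experiment builders accept a
# policy wrapper factory, e.g. via `ExperimentBuilder.with_policy_wrapper_factory`,
# as done in the Atari examples. The feature-net factory name and all
# hyperparameter values below are illustrative assumptions, not prescribed
# defaults:
#
#   from examples.atari.atari_network import IntermediateModuleFactoryAtariDQNFeatures
#
#   builder.with_policy_wrapper_factory(
#       PolicyWrapperFactoryIntrinsicCuriosity(
#           feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
#           hidden_sizes=[512],
#           lr=1e-4,            # learning rate of the ICM optimizer
#           lr_scale=1.0,       # scaling factor for the ICM loss
#           reward_scale=0.01,  # scaling factor for the intrinsic reward
#           forward_loss_weight=0.2,  # forward vs. inverse model loss weight
#       ),
#   )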