from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Generic, TypeVar

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin

TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)


class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Factory for creating wrapped policies, i.e. policies that augment a given base policy."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        pass


class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    """Factory for wrapping a policy with an intrinsic curiosity module (ICM),
    which provides an exploration bonus based on the prediction error of a
    learned forward dynamics model.

    :param feature_net_factory: factory for the ICM feature network
    :param hidden_sizes: hidden layer sizes of the ICM's forward and inverse models
    :param lr: learning rate for the ICM's optimizer
    :param lr_scale: scaling factor for ICM learning
    :param reward_scale: scaling factor for the intrinsic curiosity reward
    :param forward_loss_weight: weight of the forward model loss
    """

    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
        # The ICM requires a discrete action space, i.e. an integer action dimension.
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        feature_dim = feature_net.output_dim
        # The curiosity module learns forward and inverse dynamics models on top
        # of the feature network.
        icm_net = IntrinsicCuriosityModule(
            feature_net.module,
            feature_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr)
        # ICMPolicy delegates action selection to the wrapped base policy and
        # augments rewards with the curiosity bonus during training.
        return ICMPolicy(
            policy=policy,
            model=icm_net,
            optim=icm_optim,
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        ).to(device)
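
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): registering the
# curiosity wrapper with a high-level experiment builder. The `builder`
# variable and the feature-net factory are placeholders, and the method name
# `with_policy_wrapper_factory` is an assumption about the ExperimentBuilder
# API in tianshou.highlevel.experiment; verify against that module.
#
#   builder.with_policy_wrapper_factory(
#       PolicyWrapperFactoryIntrinsicCuriosity(
#           feature_net_factory=feature_net_factory,  # any IntermediateModuleFactory
#           hidden_sizes=[256, 128],  # example hyperparameters, not tuned defaults
#           lr=1e-4,
#           lr_scale=1.0,
#           reward_scale=0.01,
#           forward_loss_weight=0.2,
#       )
#   )
#
# Design note: the factory is a plain parameter holder that only touches the
# Environments, OptimizerFactory, and TDevice abstractions when
# create_wrapped_policy is called, so a wrapper configuration can be declared
# before any environments exist and reused across experiments.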