From ae4850692fe6b78b6d5228aadd0bb6e022968e0a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 16 Oct 2023 18:46:37 +0200 Subject: [PATCH] DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorFactory (similar to IQN implementation) --- examples/atari/atari_dqn_hl.py | 4 ++-- examples/atari/atari_network.py | 9 --------- tianshou/highlevel/experiment.py | 11 ++++++++--- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 466520e..fc0483f 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -5,7 +5,7 @@ import os from jsonargparse import CLI from examples.atari.atari_network import ( - ActorFactoryAtariPlainDQN, + IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback @@ -106,7 +106,7 @@ def main( target_update_freq=target_update_freq, ), ) - .with_actor_factory(ActorFactoryAtariPlainDQN()) + .with_model_factory(IntermediateModuleFactoryAtariDQN()) .with_trainer_epoch_callback_train(TrainEpochCallback()) .with_trainer_epoch_callback_test(TestEpochCallback()) .with_trainer_stop_callback(AtariStopCallback(task)) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index a371621..0ff9b50 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -231,15 +231,6 @@ class QRDQN(DQN): return obs, state -class ActorFactoryAtariPlainDQN(ActorFactory): - def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module: - return DQN( - *envs.get_observation_shape(), - envs.get_action_shape(), - device=device, - ).to(device) - - class ActorFactoryAtariDQN(ActorFactory): def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool): self.hidden_size = hidden_size diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index f22ea6d..730bc61 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -849,7 +849,6 @@ class TRPOExperimentBuilder( class DQNExperimentBuilder( ExperimentBuilder, - _BuilderMixinActorFactory, ): def __init__( self, @@ -858,18 +857,24 @@ class DQNExperimentBuilder( sampling_config: SamplingConfig | None = None, ): super().__init__(env_factory, experiment_config, sampling_config) - _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) self._params: DQNParams = DQNParams() + self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory( + ActorFactoryDefault(ContinuousActorType.UNSUPPORTED), + ) def with_dqn_params(self, params: DQNParams) -> Self: self._params = params return self + def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self: + self._model_factory = module_factory + return self + def _create_agent_factory(self) -> AgentFactory: return DQNAgentFactory( self._params, self._sampling_config, - self._get_actor_factory(), + self._model_factory, self._get_optim_factory(), )