DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorFactory

(similar to IQN implementation)
This commit is contained in:
Dominik Jain 2023-10-16 18:46:37 +02:00
parent 83048788a1
commit ae4850692f
3 changed files with 10 additions and 14 deletions

View File

@ -5,7 +5,7 @@ import os
from jsonargparse import CLI from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariPlainDQN, IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
@ -106,7 +106,7 @@ def main(
target_update_freq=target_update_freq, target_update_freq=target_update_freq,
), ),
) )
.with_actor_factory(ActorFactoryAtariPlainDQN()) .with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback()) .with_trainer_epoch_callback_train(TrainEpochCallback())
.with_trainer_epoch_callback_test(TestEpochCallback()) .with_trainer_epoch_callback_test(TestEpochCallback())
.with_trainer_stop_callback(AtariStopCallback(task)) .with_trainer_stop_callback(AtariStopCallback(task))

View File

@ -231,15 +231,6 @@ class QRDQN(DQN):
return obs, state return obs, state
class ActorFactoryAtariPlainDQN(ActorFactory):
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
return DQN(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
).to(device)
class ActorFactoryAtariDQN(ActorFactory): class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool): def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
self.hidden_size = hidden_size self.hidden_size = hidden_size

View File

@ -849,7 +849,6 @@ class TRPOExperimentBuilder(
class DQNExperimentBuilder( class DQNExperimentBuilder(
ExperimentBuilder, ExperimentBuilder,
_BuilderMixinActorFactory,
): ):
def __init__( def __init__(
self, self,
@ -858,18 +857,24 @@ class DQNExperimentBuilder(
sampling_config: SamplingConfig | None = None, sampling_config: SamplingConfig | None = None,
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
self._params: DQNParams = DQNParams() self._params: DQNParams = DQNParams()
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
)
def with_dqn_params(self, params: DQNParams) -> Self: def with_dqn_params(self, params: DQNParams) -> Self:
self._params = params self._params = params
return self return self
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
self._model_factory = module_factory
return self
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
return DQNAgentFactory( return DQNAgentFactory(
self._params, self._params,
self._sampling_config, self._sampling_config,
self._get_actor_factory(), self._model_factory,
self._get_optim_factory(), self._get_optim_factory(),
) )