DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorFactory

(similar to IQN implementation)
This commit is contained in:
Dominik Jain 2023-10-16 18:46:37 +02:00
parent 83048788a1
commit ae4850692f
3 changed files with 10 additions and 14 deletions

View File

@@ -5,7 +5,7 @@ import os
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariPlainDQN,
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
@@ -106,7 +106,7 @@ def main(
target_update_freq=target_update_freq,
),
)
.with_actor_factory(ActorFactoryAtariPlainDQN())
.with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback())
.with_trainer_epoch_callback_test(TestEpochCallback())
.with_trainer_stop_callback(AtariStopCallback(task))

View File

@@ -231,15 +231,6 @@ class QRDQN(DQN):
return obs, state
class ActorFactoryAtariPlainDQN(ActorFactory):
    """Actor factory whose module is a ``DQN`` network sized from the given environments."""

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Build the ``DQN`` network and move it to ``device``.

        :param envs: the environments queried for the observation shape and the action shape
        :param device: the device handed to the ``DQN`` constructor and to which the
            resulting module is moved
        :return: the ``DQN`` instance after ``.to(device)`` has been applied
        """
        observation_shape = envs.get_observation_shape()
        action_shape = envs.get_action_shape()
        # The observation shape is unpacked into the leading positional arguments of DQN.
        network = DQN(*observation_shape, action_shape, device=device)
        return network.to(device)
class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
self.hidden_size = hidden_size

View File

@@ -849,7 +849,6 @@ class TRPOExperimentBuilder(
class DQNExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory,
):
def __init__(
self,
@@ -858,18 +857,24 @@ class DQNExperimentBuilder(
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
self._params: DQNParams = DQNParams()
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
)
def with_dqn_params(self, params: DQNParams) -> Self:
    """Set the DQN parameters held by this builder.

    :param params: the parameters to store; they replace the ones stored so far
    :return: the builder itself, so that calls can be chained
    """
    self._params = params
    return self
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
    """Set the factory stored in ``self._model_factory``.

    :param module_factory: the intermediate module factory to store; it replaces the
        one stored so far
    :return: the builder itself, so that calls can be chained
    """
    self._model_factory = module_factory
    return self
def _create_agent_factory(self) -> AgentFactory:
    """Create the ``DQNAgentFactory`` from the builder's current configuration.

    :return: a ``DQNAgentFactory`` constructed from the stored DQN parameters, the
        sampling configuration and the factories held by this builder
    """
    # NOTE(review): both `self._get_actor_factory()` and `self._model_factory` are passed
    # here. The commit title says the model factory is used *instead of* the actor factory,
    # so one of these two arguments may be a leftover - confirm against the signature of
    # DQNAgentFactory before relying on this call.
    return DQNAgentFactory(
        self._params,
        self._sampling_config,
        self._get_actor_factory(),
        self._model_factory,
        self._get_optim_factory(),
    )