DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorFactory
(similar to IQN implementation)
This commit is contained in:
parent
83048788a1
commit
ae4850692f
@ -5,7 +5,7 @@ import os
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
ActorFactoryAtariPlainDQN,
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
@ -106,7 +106,7 @@ def main(
|
||||
target_update_freq=target_update_freq,
|
||||
),
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariPlainDQN())
|
||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
|
@ -231,15 +231,6 @@ class QRDQN(DQN):
|
||||
return obs, state
|
||||
|
||||
|
||||
class ActorFactoryAtariPlainDQN(ActorFactory):
|
||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||
return DQN(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
).to(device)
|
||||
|
||||
|
||||
class ActorFactoryAtariDQN(ActorFactory):
|
||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -849,7 +849,6 @@ class TRPOExperimentBuilder(
|
||||
|
||||
class DQNExperimentBuilder(
|
||||
ExperimentBuilder,
|
||||
_BuilderMixinActorFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@ -858,18 +857,24 @@ class DQNExperimentBuilder(
|
||||
sampling_config: SamplingConfig | None = None,
|
||||
):
|
||||
super().__init__(env_factory, experiment_config, sampling_config)
|
||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||
self._params: DQNParams = DQNParams()
|
||||
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
|
||||
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
|
||||
)
|
||||
|
||||
def with_dqn_params(self, params: DQNParams) -> Self:
|
||||
self._params = params
|
||||
return self
|
||||
|
||||
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
|
||||
self._model_factory = module_factory
|
||||
return self
|
||||
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
return DQNAgentFactory(
|
||||
self._params,
|
||||
self._sampling_config,
|
||||
self._get_actor_factory(),
|
||||
self._model_factory,
|
||||
self._get_optim_factory(),
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user