DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorFactory
(similar to IQN implementation)
This commit is contained in:
parent
83048788a1
commit
ae4850692f
@ -5,7 +5,7 @@ import os
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
ActorFactoryAtariPlainDQN,
|
IntermediateModuleFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQNFeatures,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
@ -106,7 +106,7 @@ def main(
|
|||||||
target_update_freq=target_update_freq,
|
target_update_freq=target_update_freq,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory(ActorFactoryAtariPlainDQN())
|
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
|
@ -231,15 +231,6 @@ class QRDQN(DQN):
|
|||||||
return obs, state
|
return obs, state
|
||||||
|
|
||||||
|
|
||||||
class ActorFactoryAtariPlainDQN(ActorFactory):
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
|
||||||
return DQN(
|
|
||||||
*envs.get_observation_shape(),
|
|
||||||
envs.get_action_shape(),
|
|
||||||
device=device,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
|
|
||||||
class ActorFactoryAtariDQN(ActorFactory):
|
class ActorFactoryAtariDQN(ActorFactory):
|
||||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
|
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
@ -849,7 +849,6 @@ class TRPOExperimentBuilder(
|
|||||||
|
|
||||||
class DQNExperimentBuilder(
|
class DQNExperimentBuilder(
|
||||||
ExperimentBuilder,
|
ExperimentBuilder,
|
||||||
_BuilderMixinActorFactory,
|
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -858,18 +857,24 @@ class DQNExperimentBuilder(
|
|||||||
sampling_config: SamplingConfig | None = None,
|
sampling_config: SamplingConfig | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
|
||||||
self._params: DQNParams = DQNParams()
|
self._params: DQNParams = DQNParams()
|
||||||
|
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
|
||||||
|
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
|
||||||
|
)
|
||||||
|
|
||||||
def with_dqn_params(self, params: DQNParams) -> Self:
|
def with_dqn_params(self, params: DQNParams) -> Self:
|
||||||
self._params = params
|
self._params = params
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
|
||||||
|
self._model_factory = module_factory
|
||||||
|
return self
|
||||||
|
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
return DQNAgentFactory(
|
return DQNAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
self._sampling_config,
|
self._sampling_config,
|
||||||
self._get_actor_factory(),
|
self._model_factory,
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user