Change high-level DQN interface to expect an actor instead of a critic,
because that is what is functionally required
This commit is contained in:
parent
1cba589bd4
commit
b54fcd12cb
@ -6,7 +6,7 @@ import os
|
||||
from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
CriticFactoryAtariDQN,
|
||||
ActorFactoryAtariPlainDQN,
|
||||
FeatureNetFactoryDQN,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
@ -103,7 +103,7 @@ def main(
|
||||
target_update_freq=target_update_freq,
|
||||
),
|
||||
)
|
||||
.with_critic_factory(CriticFactoryAtariDQN())
|
||||
.with_actor_factory(ActorFactoryAtariPlainDQN())
|
||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
|
@ -8,8 +8,6 @@ from torch import nn
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.actor import ActorFactory
|
||||
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
|
||||
from tianshou.highlevel.module.critic import CriticFactory
|
||||
from tianshou.utils.net.common import BaseActor
|
||||
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
||||
|
||||
|
||||
@ -227,14 +225,8 @@ class QRDQN(DQN):
|
||||
return obs, state
|
||||
|
||||
|
||||
class CriticFactoryAtariDQN(CriticFactory):
|
||||
def create_module(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
use_action: bool,
|
||||
) -> torch.nn.Module:
|
||||
assert use_action
|
||||
class ActorFactoryAtariPlainDQN(ActorFactory):
|
||||
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||
return DQN(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
@ -243,17 +235,18 @@ class CriticFactoryAtariDQN(CriticFactory):
|
||||
|
||||
|
||||
class ActorFactoryAtariDQN(ActorFactory):
|
||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
|
||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
|
||||
self.hidden_size = hidden_size
|
||||
self.scale_obs = scale_obs
|
||||
self.features_only = features_only
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
||||
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
||||
net = net_cls(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
features_only=True,
|
||||
features_only=self.features_only,
|
||||
output_dim=self.hidden_size,
|
||||
layer_init=layer_init,
|
||||
)
|
||||
|
@ -96,7 +96,7 @@ def main(
|
||||
else None,
|
||||
),
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
||||
.with_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
)
|
||||
|
@ -44,7 +44,7 @@ from tianshou.policy import (
|
||||
)
|
||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||
from tianshou.utils.net import continuous, discrete
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.net.common import ActorCritic, BaseActor
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||
@ -285,6 +285,10 @@ class _ActorCriticMixin:
|
||||
raise ValueError(
|
||||
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
|
||||
)
|
||||
if not isinstance(actor, BaseActor):
    # The critic built right after this guard calls actor.get_preprocess_net(),
    # so an actor that is not a BaseActor is rejected here with a clear message.
    # Fixes: the message previously named the wrong option (critic_use_action)
    # and printed the metaclass name via BaseActor.__class__.__name__ instead
    # of the class name itself.
    raise ValueError(
        "Option critic_use_actor_module can only be used if actor is of type "
        f"{BaseActor.__name__}",
    )
|
||||
if envs.get_type().is_discrete():
|
||||
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
|
||||
elif envs.get_type().is_continuous():
|
||||
@ -444,17 +448,17 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
||||
self,
|
||||
params: DQNParams,
|
||||
sampling_config: RLSamplingConfig,
|
||||
critic_factory: CriticFactory,
|
||||
actor_factory: ActorFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
):
|
||||
super().__init__(sampling_config, optim_factory)
|
||||
self.params = params
|
||||
self.critic_factory = critic_factory
|
||||
self.actor_factory = actor_factory
|
||||
self.optim_factory = optim_factory
|
||||
|
||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=True)
|
||||
optim = self.optim_factory.create_optimizer(critic, self.params.lr)
|
||||
model = self.actor_factory.create_module(envs, device)
|
||||
optim = self.optim_factory.create_optimizer(model, self.params.lr)
|
||||
kwargs = self.params.create_kwargs(
|
||||
ParamTransformerData(
|
||||
envs=envs,
|
||||
@ -467,7 +471,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
||||
# noinspection PyTypeChecker
|
||||
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
|
||||
return DQNPolicy(
|
||||
model=critic,
|
||||
model=model,
|
||||
optim=optim,
|
||||
action_space=action_space,
|
||||
observation_space=envs.get_observation_space(),
|
||||
|
@ -461,7 +461,7 @@ class PPOExperimentBuilder(
|
||||
|
||||
class DQNExperimentBuilder(
|
||||
RLExperimentBuilder,
|
||||
_BuilderMixinSingleCriticFactory,
|
||||
_BuilderMixinActorFactory,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@ -470,7 +470,7 @@ class DQNExperimentBuilder(
|
||||
sampling_config: RLSamplingConfig,
|
||||
):
|
||||
super().__init__(experiment_config, env_factory, sampling_config)
|
||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
||||
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||
self._params: DQNParams = DQNParams()
|
||||
|
||||
def with_dqn_params(self, params: DQNParams) -> Self:
|
||||
@ -482,7 +482,7 @@ class DQNExperimentBuilder(
|
||||
return DQNAgentFactory(
|
||||
self._params,
|
||||
self._sampling_config,
|
||||
self._get_critic_factory(0),
|
||||
self._get_actor_factory(),
|
||||
self._get_optim_factory(),
|
||||
)
|
||||
|
||||
|
@ -14,11 +14,12 @@ from tianshou.utils.string import ToStringMixin
|
||||
class ContinuousActorType:
    """String constants naming the kind of actor to build for continuous action spaces."""

    # Gaussian actor variant; the code that handles this value lies outside the
    # visible hunk - TODO confirm which factory it selects.
    GAUSSIAN = "gaussian"
    # Resolved by ActorFactoryDefault to ActorFactoryContinuousDeterministicNet.
    DETERMINISTIC = "deterministic"
    # Used by builders whose algorithm has no continuous-action support
    # (DQNExperimentBuilder passes it); ActorFactoryDefault raises ValueError
    # when it has to resolve this value.
    UNSUPPORTED = "unsupported"
|
||||
|
||||
|
||||
class ActorFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@ -67,11 +68,14 @@ class ActorFactoryDefault(ActorFactory):
|
||||
)
|
||||
case ContinuousActorType.DETERMINISTIC:
|
||||
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
|
||||
case ContinuousActorType.UNSUPPORTED:
|
||||
raise ValueError("Continuous action spaces are not supported by the algorithm")
|
||||
case _:
|
||||
raise ValueError(self.continuous_actor_type)
|
||||
return factory.create_module(envs, device)
|
||||
elif env_type == EnvType.DISCRETE:
|
||||
raise NotImplementedError
|
||||
factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
|
||||
return factory.create_module(envs, device)
|
||||
else:
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
@ -5,7 +5,7 @@ from torch import nn
|
||||
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||
from tianshou.utils.net import continuous
|
||||
from tianshou.utils.net import continuous, discrete
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
@ -30,16 +30,13 @@ class CriticFactoryDefault(CriticFactory):
|
||||
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
||||
return factory.create_module(envs, device, use_action)
|
||||
elif env_type == EnvType.DISCRETE:
|
||||
raise NotImplementedError
|
||||
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
||||
return factory.create_module(envs, device, use_action)
|
||||
else:
|
||||
raise ValueError(f"{env_type} not supported")
|
||||
|
||||
|
||||
class CriticFactoryContinuous(CriticFactory, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class CriticFactoryContinuousNet(CriticFactoryContinuous):
|
||||
class CriticFactoryContinuousNet(CriticFactory):
|
||||
def __init__(self, hidden_sizes: Sequence[int]):
|
||||
self.hidden_sizes = hidden_sizes
|
||||
|
||||
@ -56,3 +53,22 @@ class CriticFactoryContinuousNet(CriticFactoryContinuous):
|
||||
critic = continuous.Critic(net_c, device=device).to(device)
|
||||
init_linear_orthogonal(critic)
|
||||
return critic
|
||||
|
||||
|
||||
class CriticFactoryDiscreteNet(CriticFactory):
    """Critic factory that wraps a `Net` in a `discrete.Critic`.

    :param hidden_sizes: hidden layer sizes forwarded to the underlying `Net`.
    """

    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic module for the given environments.

        :param envs: environments providing the observation and action shapes.
        :param device: device passed to the networks and to which the critic is moved.
        :param use_action: if True, the environment's action shape is passed to the
            network and `concat` is enabled; otherwise an action shape of 0 is used.
        :return: the critic, after `init_linear_orthogonal` has been applied to it.
        """
        if use_action:
            action_shape = envs.get_action_shape()
        else:
            action_shape = 0

        preprocess_net = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )

        module = discrete.Critic(preprocess_net, device=device).to(device)
        init_linear_orthogonal(module)
        return module
|
||||
|
Loading…
x
Reference in New Issue
Block a user