Change high-level DQN interface to expect an actor instead of a critic,

because that is what is functionally required
This commit is contained in:
Dominik Jain 2023-10-05 19:21:08 +02:00
parent 1cba589bd4
commit b54fcd12cb
7 changed files with 51 additions and 34 deletions

View File

@ -6,7 +6,7 @@ import os
from jsonargparse import CLI from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
CriticFactoryAtariDQN, ActorFactoryAtariPlainDQN,
FeatureNetFactoryDQN, FeatureNetFactoryDQN,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
@ -103,7 +103,7 @@ def main(
target_update_freq=target_update_freq, target_update_freq=target_update_freq,
), ),
) )
.with_critic_factory(CriticFactoryAtariDQN()) .with_actor_factory(ActorFactoryAtariPlainDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback()) .with_trainer_epoch_callback_train(TrainEpochCallback())
.with_trainer_epoch_callback_test(TestEpochCallback()) .with_trainer_epoch_callback_test(TestEpochCallback())
.with_trainer_stop_callback(AtariStopCallback(task)) .with_trainer_stop_callback(AtariStopCallback(task))

View File

@ -8,8 +8,6 @@ from torch import nn
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.utils.net.common import BaseActor
from tianshou.utils.net.discrete import Actor, NoisyLinear from tianshou.utils.net.discrete import Actor, NoisyLinear
@ -227,14 +225,8 @@ class QRDQN(DQN):
return obs, state return obs, state
class CriticFactoryAtariDQN(CriticFactory): class ActorFactoryAtariPlainDQN(ActorFactory):
def create_module( def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
self,
envs: Environments,
device: TDevice,
use_action: bool,
) -> torch.nn.Module:
assert use_action
return DQN( return DQN(
*envs.get_observation_shape(), *envs.get_observation_shape(),
envs.get_action_shape(), envs.get_action_shape(),
@ -243,17 +235,18 @@ class CriticFactoryAtariDQN(CriticFactory):
class ActorFactoryAtariDQN(ActorFactory): class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool): def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.scale_obs = scale_obs self.scale_obs = scale_obs
self.features_only = features_only
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> Actor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls( net = net_cls(
*envs.get_observation_shape(), *envs.get_observation_shape(),
envs.get_action_shape(), envs.get_action_shape(),
device=device, device=device,
features_only=True, features_only=self.features_only,
output_dim=self.hidden_size, output_dim=self.hidden_size,
layer_init=layer_init, layer_init=layer_init,
) )

View File

@ -96,7 +96,7 @@ def main(
else None, else None,
), ),
) )
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs)) .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
.with_critic_factory_use_actor() .with_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task)) .with_trainer_stop_callback(AtariStopCallback(task))
) )

View File

@ -44,7 +44,7 @@ from tianshou.policy import (
) )
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic, BaseActor
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
@ -285,6 +285,10 @@ class _ActorCriticMixin:
raise ValueError( raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive", "The options critic_use_actor_module and critic_use_action are mutually exclusive",
) )
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete(): if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device) critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous(): elif envs.get_type().is_continuous():
@ -444,17 +448,17 @@ class DQNAgentFactory(OffpolicyAgentFactory):
self, self,
params: DQNParams, params: DQNParams,
sampling_config: RLSamplingConfig, sampling_config: RLSamplingConfig,
critic_factory: CriticFactory, actor_factory: ActorFactory,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
): ):
super().__init__(sampling_config, optim_factory) super().__init__(sampling_config, optim_factory)
self.params = params self.params = params
self.critic_factory = critic_factory self.actor_factory = actor_factory
self.optim_factory = optim_factory self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
critic = self.critic_factory.create_module(envs, device, use_action=True) model = self.actor_factory.create_module(envs, device)
optim = self.optim_factory.create_optimizer(critic, self.params.lr) optim = self.optim_factory.create_optimizer(model, self.params.lr)
kwargs = self.params.create_kwargs( kwargs = self.params.create_kwargs(
ParamTransformerData( ParamTransformerData(
envs=envs, envs=envs,
@ -467,7 +471,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
# noinspection PyTypeChecker # noinspection PyTypeChecker
action_space: gymnasium.spaces.Discrete = envs.get_action_space() action_space: gymnasium.spaces.Discrete = envs.get_action_space()
return DQNPolicy( return DQNPolicy(
model=critic, model=model,
optim=optim, optim=optim,
action_space=action_space, action_space=action_space,
observation_space=envs.get_observation_space(), observation_space=envs.get_observation_space(),

View File

@ -461,7 +461,7 @@ class PPOExperimentBuilder(
class DQNExperimentBuilder( class DQNExperimentBuilder(
RLExperimentBuilder, RLExperimentBuilder,
_BuilderMixinSingleCriticFactory, _BuilderMixinActorFactory,
): ):
def __init__( def __init__(
self, self,
@ -470,7 +470,7 @@ class DQNExperimentBuilder(
sampling_config: RLSamplingConfig, sampling_config: RLSamplingConfig,
): ):
super().__init__(experiment_config, env_factory, sampling_config) super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinSingleCriticFactory.__init__(self) _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
self._params: DQNParams = DQNParams() self._params: DQNParams = DQNParams()
def with_dqn_params(self, params: DQNParams) -> Self: def with_dqn_params(self, params: DQNParams) -> Self:
@ -482,7 +482,7 @@ class DQNExperimentBuilder(
return DQNAgentFactory( return DQNAgentFactory(
self._params, self._params,
self._sampling_config, self._sampling_config,
self._get_critic_factory(0), self._get_actor_factory(),
self._get_optim_factory(), self._get_optim_factory(),
) )

View File

@ -14,11 +14,12 @@ from tianshou.utils.string import ToStringMixin
class ContinuousActorType: class ContinuousActorType:
GAUSSIAN = "gaussian" GAUSSIAN = "gaussian"
DETERMINISTIC = "deterministic" DETERMINISTIC = "deterministic"
UNSUPPORTED = "unsupported"
class ActorFactory(ToStringMixin, ABC): class ActorFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass pass
@staticmethod @staticmethod
@ -67,11 +68,14 @@ class ActorFactoryDefault(ActorFactory):
) )
case ContinuousActorType.DETERMINISTIC: case ContinuousActorType.DETERMINISTIC:
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes) factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
case ContinuousActorType.UNSUPPORTED:
raise ValueError("Continuous action spaces are not supported by the algorithm")
case _: case _:
raise ValueError(self.continuous_actor_type) raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device) return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE: elif env_type == EnvType.DISCRETE:
raise NotImplementedError factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
return factory.create_module(envs, device)
else: else:
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")

View File

@ -5,7 +5,7 @@ from torch import nn
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -30,16 +30,13 @@ class CriticFactoryDefault(CriticFactory):
factory = CriticFactoryContinuousNet(self.hidden_sizes) factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action) return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE: elif env_type == EnvType.DISCRETE:
raise NotImplementedError factory = CriticFactoryDiscreteNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
else: else:
raise ValueError(f"{env_type} not supported") raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC): class CriticFactoryContinuousNet(CriticFactory):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
@ -56,3 +53,22 @@ class CriticFactoryContinuousNet(CriticFactoryContinuous):
critic = continuous.Critic(net_c, device=device).to(device) critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic) init_linear_orthogonal(critic)
return critic return critic
class CriticFactoryDiscreteNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = discrete.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic