Support IQN in high-level API

* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improve module factory interface design, achieving higher generality
commit a8a367c42d
parent 213e08a846
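For orientation, the sketch below condenses the new high-level IQN usage from the examples/atari/atari_iqn_hl.py file added in this commit (parameter values are illustrative, and most SamplingConfig fields are omitted):

    from examples.atari.atari_network import IntermediateModuleFactoryAtariDQN
    from examples.atari.atari_wrapper import AtariEnvFactory
    from tianshou.highlevel.config import SamplingConfig
    from tianshou.highlevel.experiment import ExperimentConfig, IQNExperimentBuilder
    from tianshou.highlevel.params.policy_params import IQNParams

    experiment_config = ExperimentConfig()
    sampling_config = SamplingConfig(num_epochs=100, step_per_epoch=100000)
    env_factory = AtariEnvFactory(
        "PongNoFrameskip-v4", experiment_config.seed, sampling_config, 4,
    )
    experiment = (
        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_iqn_params(IQNParams(discount_factor=0.99, lr=0.0001))
        # use the full features-only DQN module as the IQN preprocessing network
        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
        .build()
    )
    experiment.run("PongNoFrameskip-v4/iqn")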
examples/atari/atari_callbacks.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+from tianshou.highlevel.trainer import (
+    TrainerEpochCallbackTest,
+    TrainerEpochCallbackTrain,
+    TrainingContext,
+)
+from tianshou.policy import DQNPolicy
+
+
+class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
+    def __init__(self, eps_test: float):
+        self.eps_test = eps_test
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy: DQNPolicy = context.policy
+        policy.set_eps(self.eps_test)
+
+
+class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
+    def __init__(self, eps_train: float, eps_train_final: float):
+        self.eps_train = eps_train
+        self.eps_train_final = eps_train_final
+
+    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
+        policy: DQNPolicy = context.policy
+        logger = context.logger.logger
+        # nature DQN setting, linear decay in the first 1M steps
+        if env_step <= 1e6:
+            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
+        else:
+            eps = self.eps_train_final
+        policy.set_eps(eps)
+        if env_step % 1000 == 0:
+            logger.write("train/env_step", env_step, {"train/eps": eps})
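Note: these callbacks reproduce the usual Atari DQN epsilon handling: the test callback pins epsilon to a fixed evaluation value, while the train callback applies the Nature-DQN schedule (linear decay over the first 1M environment steps, then constant). They are attached via the experiment builder, as in this excerpt from the new atari_iqn_hl example below:

    .with_trainer_epoch_callback_train(
        TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
    )
    .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))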
examples/atari/atari_dqn_hl.py

@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
 
-import datetime
 import os
 
 from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariPlainDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,6 +25,7 @@ from tianshou.highlevel.trainer import (
 )
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
 
 
 def main(
@@ -53,8 +53,7 @@ def main(
     icm_reward_scale: float = 0.01,
     icm_forward_loss_weight: float = 0.2,
 ):
-    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())
 
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -115,7 +114,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [512],
                 lr,
                 icm_lr_scale,
examples/atari/atari_iqn_hl.py (new file, 105 lines)

@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+
+import os
+from collections.abc import Sequence
+
+from jsonargparse import CLI
+
+from examples.atari.atari_callbacks import (
+    TestEpochCallbackDQNSetEps,
+    TrainEpochCallbackNatureDQNEpsLinearDecay,
+)
+from examples.atari.atari_network import (
+    IntermediateModuleFactoryAtariDQN,
+)
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.experiment import (
+    ExperimentConfig,
+    IQNExperimentBuilder,
+)
+from tianshou.highlevel.params.policy_params import IQNParams
+from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
+
+
+def main(
+    experiment_config: ExperimentConfig,
+    task: str = "PongNoFrameskip-v4",
+    scale_obs: int = 0,
+    eps_test: float = 0.005,
+    eps_train: float = 1.0,
+    eps_train_final: float = 0.05,
+    buffer_size: int = 100000,
+    lr: float = 0.0001,
+    gamma: float = 0.99,
+    sample_size: int = 32,
+    online_sample_size: int = 8,
+    target_sample_size: int = 8,
+    num_cosines: int = 64,
+    hidden_sizes: Sequence[int] = (512,),
+    n_step: int = 3,
+    target_update_freq: int = 500,
+    epoch: int = 100,
+    step_per_epoch: int = 100000,
+    step_per_collect: int = 10,
+    update_per_step: float = 0.1,
+    batch_size: int = 32,
+    training_num: int = 10,
+    test_num: int = 10,
+    frames_stack: int = 4,
+    save_buffer_name: str | None = None,  # TODO support?
+):
+    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())
+
+    sampling_config = SamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        update_per_step=update_per_step,
+        repeat_per_collect=None,
+        replay_buffer_stack_num=frames_stack,
+        replay_buffer_ignore_obs_next=True,
+        replay_buffer_save_only_last_obs=True,
+    )
+
+    env_factory = AtariEnvFactory(
+        task,
+        experiment_config.seed,
+        sampling_config,
+        frames_stack,
+        scale=scale_obs,
+    )
+
+    experiment = (
+        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
+        .with_iqn_params(
+            IQNParams(
+                discount_factor=gamma,
+                estimation_step=n_step,
+                lr=lr,
+                sample_size=sample_size,
+                online_sample_size=online_sample_size,
+                target_update_freq=target_update_freq,
+                target_sample_size=target_sample_size,
+                hidden_sizes=hidden_sizes,
+                num_cosines=num_cosines,
+            ),
+        )
+        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
+        .with_trainer_epoch_callback_train(
+            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+        )
+        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
+        .with_trainer_stop_callback(AtariStopCallback(task))
+        .build()
+    )
+    experiment.run(log_name)
+
+
+if __name__ == "__main__":
+    logging.run_main(lambda: CLI(main))
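Note: since main is exposed through jsonargparse's CLI, each parameter above becomes a command-line option; assuming jsonargparse's standard flag mapping, e.g. python examples/atari/atari_iqn_hl.py --task BreakoutNoFrameskip-v4 --epoch 50.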
examples/atari/atari_network.py

@@ -7,7 +7,11 @@ from torch import nn
 
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
-from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    TDevice,
+)
 from tianshou.utils.net.discrete import Actor, NoisyLinear
 
 
@@ -253,12 +257,15 @@ class ActorFactoryAtariDQN(ActorFactory):
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
 
 
-class FeatureNetFactoryDQN(ModuleFactory):
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
+    def __init__(self, net_only: bool):
+        self.net_only = net_only
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         dqn = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
-            device,
+            device=device,
             features_only=True,
         )
-        return Module(dqn.net, dqn.output_dim)
+        return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
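Note: the net_only flag controls how much of the features-only DQN is exposed: net_only=True yields just the CNN trunk (dqn.net), which is what the intrinsic-curiosity wrapper needs as a feature net, while net_only=False yields the full module, which serves as the IQN preprocessing network. Both call sites appear in this commit:

    # as ICM feature net: bare CNN trunk
    PolicyWrapperFactoryIntrinsicCuriosity(
        IntermediateModuleFactoryAtariDQN(net_only=True),
        [512],
        lr,
        icm_lr_scale,
    )

    # as IQN preprocessing network: the full features-only DQN module
    builder.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))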
examples/atari/atari_ppo_hl.py

@@ -8,7 +8,7 @@ from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -103,7 +103,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_sizes],
                 lr,
                 icm_lr_scale,
examples/atari/atari_sac_hl.py

@@ -6,7 +6,7 @@ from jsonargparse import CLI
 
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,7 +26,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: bool = False,
+    scale_obs: int = 0,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,
@@ -67,7 +67,9 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
+    env_factory = AtariEnvFactory(
+        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+    )
 
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -82,14 +84,14 @@ def main(
                 estimation_step=n_step,
             ),
         )
-        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
+        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
        .with_common_critic_factory_use_actor()
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                [hidden_size],
                actor_lr,
                icm_lr_scale,
tianshou/highlevel/agent.py

@@ -13,7 +13,10 @@ from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
-from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.core import (
+    ModuleFactory,
+    TDevice,
+)
 from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
@@ -24,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     Params,
     ParamsMixinActorAndDualCritics,
@@ -44,6 +48,7 @@ from tianshou.policy import (
     DDPGPolicy,
     DiscreteSACPolicy,
     DQNPolicy,
+    IQNPolicy,
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
@@ -61,6 +66,9 @@ CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 TParams = TypeVar("TParams", bound=Params)
 TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
 TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
+TDiscreteCriticOnlyParams = TypeVar(
+    "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
+)
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 
 
@@ -394,21 +402,27 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
         return TRPOPolicy
 
 
-class DQNAgentFactory(OffpolicyAgentFactory):
+class DiscreteCriticOnlyAgentFactory(
+    OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
+):
     def __init__(
         self,
-        params: DQNParams,
+        params: TDiscreteCriticOnlyParams,
         sampling_config: SamplingConfig,
-        actor_factory: ActorFactory,
+        model_factory: ModuleFactory,
         optim_factory: OptimizerFactory,
     ):
         super().__init__(sampling_config, optim_factory)
         self.params = params
-        self.actor_factory = actor_factory
+        self.model_factory = model_factory
         self.optim_factory = optim_factory
 
-    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
-        model = self.actor_factory.create_module(envs, device)
+    @abstractmethod
+    def _get_policy_class(self) -> type[TPolicy]:
+        pass
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        model = self.model_factory.create_module(envs, device)
         optim = self.optim_factory.create_optimizer(model, self.params.lr)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -420,7 +434,8 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
         envs.get_type().assert_discrete(self)
         action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
-        return DQNPolicy(
+        policy_class = self._get_policy_class()
+        return policy_class(
             model=model,
             optim=optim,
             action_space=action_space,
@@ -429,6 +444,16 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
 
 
+class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]):
+    def _get_policy_class(self) -> type[DQNPolicy]:
+        return DQNPolicy
+
+
+class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
+    def _get_policy_class(self) -> type[IQNPolicy]:
+        return IQNPolicy
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
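Note: with policy construction factored into DiscreteCriticOnlyAgentFactory, supporting a further discrete critic-only algorithm reduces to naming its params/policy pair. A hypothetical sketch (FQF support is not part of this commit; FQFParams is an assumed dataclass analogous to IQNParams):

    # hypothetical subclass; FQFParams does not exist in this commit
    class FQFAgentFactory(DiscreteCriticOnlyAgentFactory[FQFParams, FQFPolicy]):
        def _get_policy_class(self) -> type[FQFPolicy]:
            return FQFPolicy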
tianshou/highlevel/experiment.py

@@ -15,6 +15,7 @@ from tianshou.highlevel.agent import (
     DDPGAgentFactory,
     DiscreteSACAgentFactory,
     DQNAgentFactory,
+    IQNAgentFactory,
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
@@ -33,6 +34,11 @@ from tianshou.highlevel.module.actor import (
     ActorFuture,
     ActorFutureProviderProtocol,
     ContinuousActorType,
+    IntermediateModuleFactoryFromActorFactory,
+)
+from tianshou.highlevel.module.core import (
+    ImplicitQuantileNetworkFactory,
+    IntermediateModuleFactory,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -47,6 +53,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     PGParams,
     PPOParams,
@@ -641,6 +648,41 @@ class DQNExperimentBuilder(
         )
 
 
+class IQNExperimentBuilder(ExperimentBuilder):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        self._params: IQNParams = IQNParams()
+        self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
+            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+        )
+
+    def with_iqn_params(self, params: IQNParams) -> Self:
+        self._params = params
+        return self
+
+    def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
+        self._preprocess_network_factory = module_factory
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        model_factory = ImplicitQuantileNetworkFactory(
+            self._preprocess_network_factory,
+            hidden_sizes=self._params.hidden_sizes,
+            num_cosines=self._params.num_cosines,
+        )
+        return IQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            model_factory,
+            self._get_optim_factory(),
+        )
+
+
 class DDPGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
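Note: unless overridden via with_preprocess_network_factory, IQNExperimentBuilder derives its preprocessing network from the default actor factory through IntermediateModuleFactoryFromActorFactory; the Atari example instead injects the DQN-based factory.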
tianshou/highlevel/module/actor.py

@@ -8,7 +8,13 @@ import torch
 from torch import nn
 
 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    ModuleFactory,
+    TDevice,
+    init_linear_orthogonal,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -34,7 +40,7 @@ class ActorFutureProviderProtocol(Protocol):
     pass
 
 
-class ActorFactory(ToStringMixin, ABC):
+class ActorFactory(ModuleFactory, ToStringMixin, ABC):
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
@@ -212,3 +218,13 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
         module = self.actor_factory.create_module(envs, device)
         self._actor_future.actor = module
         return module
+
+
+class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
+    def __init__(self, actor_factory: ActorFactory):
+        self.actor_factory = actor_factory
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        actor = self.actor_factory.create_module(envs, device)
+        assert isinstance(actor, BaseActor)
+        return IntermediateModule(actor, actor.get_output_dim())
tianshou/highlevel/module/core.py

@@ -7,7 +7,7 @@ import numpy as np
 import torch
 
 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
 from tianshou.utils.string import ToStringMixin
 
 TDevice: TypeAlias = str | torch.device
@@ -24,24 +24,45 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
         torch.nn.init.zeros_(m.bias)
 
 
+class ModuleFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        pass
+
+
 @dataclass
-class Module:
+class IntermediateModule:
     module: torch.nn.Module
     output_dim: int
 
 
-class ModuleFactory(ToStringMixin, ABC):
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
     @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         pass
+
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module
 
 
-class ModuleFactoryNet(
-    ModuleFactory,
-):  # TODO This is unused and broken; use it in ActorFactory* and so on?
-    def __init__(self, hidden_sizes: int | Sequence[int]):
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
         self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines
 
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
-        module = Net(envs.get_observation_shape())
-        return Module(module, module.output_dim)
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)
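Note the inheritance direction established here: IntermediateModuleFactory satisfies the minimal ModuleFactory protocol by delegating to create_intermediate_module and discarding the output dimension, so every intermediate-module factory can also be used where a plain module factory is expected. A sketch, assuming envs and device are already constructed:

    factory = IntermediateModuleFactoryAtariDQN(net_only=False)
    intermediate = factory.create_intermediate_module(envs, device)  # module plus output_dim
    module = factory.create_module(envs, device)  # just the underlying torch.nn.Module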
tianshou/highlevel/params/policy_params.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol
 
@@ -388,6 +389,23 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
         return transformers
 
 
+@dataclass
+class IQNParams(DQNParams):
+    sample_size: int = 32
+    online_sample_size: int = 8
+    target_sample_size: int = 8
+    num_quantiles: int = 200
+    hidden_sizes: Sequence[int] = ()
+    """hidden dimensions to use in the IQN network"""
+    num_cosines: int = 64
+    """number of cosines to use in the IQN network"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005
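Note: hidden_sizes and num_cosines configure the ImplicitQuantileNetwork constructed by the experiment builder rather than IQNPolicy itself, which is why ParamTransformerDrop removes them before the remaining parameters are handed to the policy constructor as keyword arguments.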
tianshou/highlevel/params/policy_wrapper.py

@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar
 
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@@ -31,7 +31,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
 ):
     def __init__(
         self,
-        feature_net_factory: ModuleFactory,
+        feature_net_factory: IntermediateModuleFactory,
         hidden_sizes: Sequence[int],
         lr: float,
         lr_scale: float,
@@ -52,7 +52,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> ICMPolicy:
-        feature_net = self.feature_net_factory.create_module(envs, device)
+        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
         action_dim = envs.get_action_shape()
         if not isinstance(action_dim, int):
             raise ValueError(f"Environment action shape must be an integer, got {action_dim}")