diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f9e493b..6cc43d9 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -1,4 +1,5 @@ import logging +import typing from abc import ABC, abstractmethod from typing import Any, Generic, TypeVar, cast @@ -63,11 +64,17 @@ from tianshou.utils.string import ToStringMixin CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" TParams = TypeVar("TParams", bound=Params) -TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler) -TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) +TActorCriticParams = TypeVar( + "TActorCriticParams", + bound=Params | ParamsMixinLearningRateWithScheduler, +) +TActorDualCriticsParams = TypeVar( + "TActorDualCriticsParams", + bound=Params | ParamsMixinActorAndDualCritics, +) TDiscreteCriticOnlyParams = TypeVar( "TDiscreteCriticOnlyParams", - bound=ParamsMixinLearningRateWithScheduler, + bound=Params | ParamsMixinLearningRateWithScheduler, ) TPolicy = TypeVar("TPolicy", bound=BasePolicy) log = logging.getLogger(__name__) @@ -321,6 +328,7 @@ class ActorCriticAgentFactory( optim = self.optim_factory.create_optimizer(actor_critic, lr) return ActorCriticModuleOpt(actor_critic, optim) + @typing.no_type_check def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) kwargs = self.params.create_kwargs( @@ -382,6 +390,7 @@ class DiscreteCriticOnlyAgentFactory( def _get_policy_class(self) -> type[TPolicy]: pass + @typing.no_type_check def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: model = self.model_factory.create_module(envs, device) optim = self.optim_factory.create_optimizer(model, self.params.lr) @@ -548,6 +557,7 @@ class ActorDualCriticsAgentFactory( def _get_critic_use_action(envs: Environments) -> bool: return envs.get_type().is_continuous() + @typing.no_type_check def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: actor = self.actor_factory.create_module_opt( envs, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 179f941..f356f4d 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -37,7 +37,7 @@ class Environments(ToStringMixin, ABC): self.env = env self.train_envs = train_envs self.test_envs = test_envs - self.persistence = [] + self.persistence: Sequence[Persistence] = [] def _tostring_includes(self) -> list[str]: return [] @@ -51,7 +51,7 @@ class Environments(ToStringMixin, ABC): "state_shape": self.get_observation_shape(), } - def set_persistence(self, *p: Persistence): + def set_persistence(self, *p: Persistence) -> None: self.persistence = p @abstractmethod diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 8394111..a5c72f0 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -27,7 +27,7 @@ from tianshou.highlevel.agent import ( ) from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import EnvFactory, Environments -from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory +from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger from tianshou.highlevel.module.actor import ( ActorFactory, ActorFactoryDefault, @@ -142,13 +142,18 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): self.env_config = env_config @classmethod - def from_directory(cls, directory: str) -> Self: + def from_directory(cls, directory: str, restore_policy: bool = True) -> Self: """Restores an experiment from a previously stored pickle. :param directory: persistence directory of a previous run, in which a pickled experiment is found + :param restore_policy: whether the experiment shall be configured to restore the policy that was + persisted in the given directory """ with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f: - return pickle.load(f) + experiment: Experiment = pickle.load(f) + if restore_policy: + experiment.config.policy_restore_directory = directory + return experiment def _set_seed(self) -> None: seed = self.config.seed @@ -159,7 +164,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): def _build_config_dict(self) -> dict: return {"experiment": self.pprints()} - def save(self, directory: str): + def save(self, directory: str) -> None: path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME) log.info( f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')", @@ -187,7 +192,8 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): os.makedirs(persistence_dir, exist_ok=True) with logging.FileLoggerContext( - os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence, + os.path.join(persistence_dir, self.LOG_FILENAME), + enabled=use_persistence, ): # log initial information log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}") @@ -209,6 +215,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): # initialize logger full_config = self._build_config_dict() full_config.update(envs.info()) + logger: TLogger if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, @@ -460,7 +467,7 @@ class _BuilderMixinCriticsFactory: class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): - def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None: + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: super().__init__(1, actor_future_provider) def with_critic_factory(self, critic_factory: CriticFactory) -> Self: @@ -553,7 +560,7 @@ class _BuilderMixinCriticEnsembleFactory: self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes) return self - def _get_critic_ensemble_factory(self): + def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory: if self.critic_ensemble_factory is None: return CriticEnsembleFactoryDefault() else: @@ -745,8 +752,10 @@ class IQNExperimentBuilder(ExperimentBuilder): ): super().__init__(env_factory, experiment_config, sampling_config) self._params: IQNParams = IQNParams() - self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory( - ActorFactoryDefault(ContinuousActorType.UNSUPPORTED), + self._preprocess_network_factory: IntermediateModuleFactory = ( + IntermediateModuleFactoryFromActorFactory( + ActorFactoryDefault(ContinuousActorType.UNSUPPORTED), + ) ) def with_iqn_params(self, params: IQNParams) -> Self: diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 6260977..c10599a 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -4,16 +4,20 @@ from typing import Literal, TypeAlias from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger from tianshou.utils.string import ToStringMixin -TLogger: TypeAlias = TensorboardLogger | WandbLogger +TLogger: TypeAlias = BaseLogger class LoggerFactory(ToStringMixin, ABC): @abstractmethod def create_logger( - self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict, + self, + log_dir: str, + experiment_name: str, + run_id: str | None, + config_dict: dict, ) -> TLogger: """:param log_dir: path to the directory in which log data is to be stored :param experiment_name: the name of the job, which may contain os.path.sep @@ -35,7 +39,11 @@ class DefaultLoggerFactory(LoggerFactory): self.wandb_project = wandb_project def create_logger( - self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict, + self, + log_dir: str, + experiment_name: str, + run_id: str | None, + config_dict: dict, ) -> TLogger: writer = SummaryWriter(log_dir) writer.add_text( @@ -48,18 +56,18 @@ class DefaultLoggerFactory(LoggerFactory): ), ), ) - logger: TLogger - if self.logger_type == "wandb": - logger = WandbLogger( - save_interval=1, - name=experiment_name.replace(os.path.sep, "__"), - run_id=run_id, - config=config_dict, - project=self.wandb_project, - ) - logger.load(writer) - elif self.logger_type == "tensorboard": - logger = TensorboardLogger(writer) - else: - raise ValueError(f"Unknown logger type '{self.logger_type}'") - return logger + match self.logger_type: + case "wandb": + wandb_logger = WandbLogger( + save_interval=1, + name=experiment_name.replace(os.path.sep, "__"), + run_id=run_id, + config=config_dict, + project=self.wandb_project, + ) + wandb_logger.load(writer) + return wandb_logger + case "tensorboard": + return TensorboardLogger(writer) + case _: + raise ValueError(f"Unknown logger type '{self.logger_type}'") diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index e02b280..d3f31f6 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -209,16 +209,16 @@ class ActorFactoryTransientStorageDecorator(ActorFactory): self.actor_factory = actor_factory self._actor_future = actor_future - def __getstate__(self): + def __getstate__(self) -> dict: d = dict(self.__dict__) del d["_actor_future"] return d - def __setstate__(self, state): + def __setstate__(self, state: dict) -> None: self.__dict__ = state self._actor_future = ActorFuture() - def _tostring_excludes(self): + def _tostring_excludes(self) -> list[str]: return [*super()._tostring_excludes(), "_actor_future"] def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 4af97bc..54b6003 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +import numpy as np from torch import nn from tianshou.highlevel.env import Environments, EnvType @@ -61,7 +62,7 @@ class CriticFactoryDefault(CriticFactory): envs: Environments, device: TDevice, use_action: bool, - discrete_last_size_use_action_shape=False, + discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: factory: CriticFactory env_type = envs.get_type() @@ -89,7 +90,7 @@ class CriticFactoryContinuousNet(CriticFactory): envs: Environments, device: TDevice, use_action: bool, - discrete_last_size_use_action_shape=False, + discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( @@ -114,7 +115,7 @@ class CriticFactoryDiscreteNet(CriticFactory): envs: Environments, device: TDevice, use_action: bool, - discrete_last_size_use_action_shape=False, + discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( @@ -125,7 +126,9 @@ class CriticFactoryDiscreteNet(CriticFactory): activation=nn.Tanh, device=device, ) - last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 + last_size = ( + int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 + ) critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic @@ -149,7 +152,7 @@ class CriticFactoryReuseActor(CriticFactory): envs: Environments, device: TDevice, use_action: bool, - discrete_last_size_use_action_shape=False, + discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: actor = self.actor_future.actor if not isinstance(actor, BaseActor): @@ -157,7 +160,10 @@ class CriticFactoryReuseActor(CriticFactory): f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", ) if envs.get_type().is_discrete(): - last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 + # TODO get rid of this prod pattern here and elsewhere + last_size = ( + int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 + ) return discrete.Critic( actor.get_preprocess_net(), device=device, diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 225b1e8..fad4bea 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -45,12 +45,12 @@ class Persistence(ABC): pass @abstractmethod - def restore(self, event: RestoreEvent, world: World): + def restore(self, event: RestoreEvent, world: World) -> None: pass class PersistenceGroup(Persistence): - def __init__(self, *p: Persistence, enabled=True): + def __init__(self, *p: Persistence, enabled: bool = True): self.items = p self.enabled = enabled @@ -60,7 +60,7 @@ class PersistenceGroup(Persistence): for item in self.items: item.persist(event, world) - def restore(self, event: RestoreEvent, world: World): + def restore(self, event: RestoreEvent, world: World) -> None: for item in self.items: item.restore(event, world) @@ -68,7 +68,7 @@ class PersistenceGroup(Persistence): class PolicyPersistence: FILENAME = "policy.dat" - def __init__(self, additional_persistence: Persistence | None = None, enabled=True): + def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True): """:param additional_persistence: a persistence instance which is to be envoked whenever this object is used to persist/restore data :param enabled: whether persistence is enabled (restoration is always enabled) @@ -93,7 +93,7 @@ class PolicyPersistence: if self.additional_persistence is not None: self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world) - def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]: + def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]: def save_best_fn(pol: torch.nn.Module) -> None: self.persist(pol, world) diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index aa278db..6ec7c4b 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -18,7 +18,7 @@ class World: test_collector: "Collector" logger: "TLogger" persist_directory: str - restore_directory: str + restore_directory: str | None trainer: Optional["BaseTrainer"] = None def persist_path(self, filename: str) -> str: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index fb7027b..c9f5ccc 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -17,7 +17,8 @@ from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.utils import RunningMeanStd -TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution] +# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] +TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] class PGPolicy(BasePolicy): diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 60c3569..5595750 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -216,7 +216,7 @@ class ActorProb(BaseActor): def get_preprocess_net(self) -> nn.Module: return self.preprocess - def get_output_dim(self): + def get_output_dim(self) -> int: return self.output_dim def forward( diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index ccdd69d..1b5b019 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -189,7 +189,7 @@ class ImplicitQuantileNetwork(Critic): def __init__( self, preprocess_net: nn.Module, - action_shape: Sequence[int], + action_shape: Sequence[int] | int, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, preprocess_net_output_dim: int | None = None,