Make mypy happy

This commit is contained in:
Dominik Jain 2023-10-13 12:25:28 +02:00
parent 76e870207d
commit 023b33c917
11 changed files with 85 additions and 51 deletions

View File

@ -1,4 +1,5 @@
import logging
import typing
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, cast
@ -63,11 +64,17 @@ from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TActorCriticParams = TypeVar(
"TActorCriticParams",
bound=Params | ParamsMixinLearningRateWithScheduler,
)
TActorDualCriticsParams = TypeVar(
"TActorDualCriticsParams",
bound=Params | ParamsMixinActorAndDualCritics,
)
TDiscreteCriticOnlyParams = TypeVar(
"TDiscreteCriticOnlyParams",
bound=ParamsMixinLearningRateWithScheduler,
bound=Params | ParamsMixinLearningRateWithScheduler,
)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)
@ -321,6 +328,7 @@ class ActorCriticAgentFactory(
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
@typing.no_type_check
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
kwargs = self.params.create_kwargs(
@ -382,6 +390,7 @@ class DiscreteCriticOnlyAgentFactory(
def _get_policy_class(self) -> type[TPolicy]:
pass
@typing.no_type_check
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
model = self.model_factory.create_module(envs, device)
optim = self.optim_factory.create_optimizer(model, self.params.lr)
@ -548,6 +557,7 @@ class ActorDualCriticsAgentFactory(
def _get_critic_use_action(envs: Environments) -> bool:
return envs.get_type().is_continuous()
@typing.no_type_check
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt(
envs,

View File

@ -37,7 +37,7 @@ class Environments(ToStringMixin, ABC):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
self.persistence = []
self.persistence: Sequence[Persistence] = []
def _tostring_includes(self) -> list[str]:
return []
@ -51,7 +51,7 @@ class Environments(ToStringMixin, ABC):
"state_shape": self.get_observation_shape(),
}
def set_persistence(self, *p: Persistence):
def set_persistence(self, *p: Persistence) -> None:
self.persistence = p
@abstractmethod

View File

@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
from tianshou.highlevel.module.actor import (
ActorFactory,
ActorFactoryDefault,
@ -142,13 +142,18 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
self.env_config = env_config
@classmethod
def from_directory(cls, directory: str) -> Self:
def from_directory(cls, directory: str, restore_policy: bool = True) -> Self:
"""Restores an experiment from a previously stored pickle.
:param directory: persistence directory of a previous run, in which a pickled experiment is found
:param restore_policy: whether the experiment shall be configured to restore the policy that was
persisted in the given directory
"""
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
return pickle.load(f)
experiment: Experiment = pickle.load(f)
if restore_policy:
experiment.config.policy_restore_directory = directory
return experiment
def _set_seed(self) -> None:
seed = self.config.seed
@ -159,7 +164,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def _build_config_dict(self) -> dict:
return {"experiment": self.pprints()}
def save(self, directory: str):
def save(self, directory: str) -> None:
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
log.info(
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
@ -187,7 +192,8 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
os.makedirs(persistence_dir, exist_ok=True)
with logging.FileLoggerContext(
os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
os.path.join(persistence_dir, self.LOG_FILENAME),
enabled=use_persistence,
):
# log initial information
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
@ -209,6 +215,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
# initialize logger
full_config = self._build_config_dict()
full_config.update(envs.info())
logger: TLogger
if use_persistence:
logger = self.logger_factory.create_logger(
log_dir=persistence_dir,
@ -460,7 +467,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__(1, actor_future_provider)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@ -553,7 +560,7 @@ class _BuilderMixinCriticEnsembleFactory:
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
return self
def _get_critic_ensemble_factory(self):
def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
if self.critic_ensemble_factory is None:
return CriticEnsembleFactoryDefault()
else:
@ -745,8 +752,10 @@ class IQNExperimentBuilder(ExperimentBuilder):
):
super().__init__(env_factory, experiment_config, sampling_config)
self._params: IQNParams = IQNParams()
self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
self._preprocess_network_factory: IntermediateModuleFactory = (
IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
)
)
def with_iqn_params(self, params: IQNParams) -> Self:

View File

@ -4,16 +4,20 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = TensorboardLogger | WandbLogger
TLogger: TypeAlias = BaseLogger
class LoggerFactory(ToStringMixin, ABC):
@abstractmethod
def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger:
""":param log_dir: path to the directory in which log data is to be stored
:param experiment_name: the name of the job, which may contain os.path.sep
@ -35,7 +39,11 @@ class DefaultLoggerFactory(LoggerFactory):
self.wandb_project = wandb_project
def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger:
writer = SummaryWriter(log_dir)
writer.add_text(
@ -48,18 +56,18 @@ class DefaultLoggerFactory(LoggerFactory):
),
),
)
logger: TLogger
if self.logger_type == "wandb":
logger = WandbLogger(
save_interval=1,
name=experiment_name.replace(os.path.sep, "__"),
run_id=run_id,
config=config_dict,
project=self.wandb_project,
)
logger.load(writer)
elif self.logger_type == "tensorboard":
logger = TensorboardLogger(writer)
else:
raise ValueError(f"Unknown logger type '{self.logger_type}'")
return logger
match self.logger_type:
case "wandb":
wandb_logger = WandbLogger(
save_interval=1,
name=experiment_name.replace(os.path.sep, "__"),
run_id=run_id,
config=config_dict,
project=self.wandb_project,
)
wandb_logger.load(writer)
return wandb_logger
case "tensorboard":
return TensorboardLogger(writer)
case _:
raise ValueError(f"Unknown logger type '{self.logger_type}'")

View File

@ -209,16 +209,16 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
self.actor_factory = actor_factory
self._actor_future = actor_future
def __getstate__(self):
def __getstate__(self) -> dict:
d = dict(self.__dict__)
del d["_actor_future"]
return d
def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
self.__dict__ = state
self._actor_future = ActorFuture()
def _tostring_excludes(self):
def _tostring_excludes(self) -> list[str]:
return [*super()._tostring_excludes(), "_actor_future"]
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
import numpy as np
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
@ -61,7 +62,7 @@ class CriticFactoryDefault(CriticFactory):
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
factory: CriticFactory
env_type = envs.get_type()
@ -89,7 +90,7 @@ class CriticFactoryContinuousNet(CriticFactory):
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
@ -114,7 +115,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
@ -125,7 +126,9 @@ class CriticFactoryDiscreteNet(CriticFactory):
activation=nn.Tanh,
device=device,
)
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
last_size = (
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
)
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
init_linear_orthogonal(critic)
return critic
@ -149,7 +152,7 @@ class CriticFactoryReuseActor(CriticFactory):
envs: Environments,
device: TDevice,
use_action: bool,
discrete_last_size_use_action_shape=False,
discrete_last_size_use_action_shape: bool = False,
) -> nn.Module:
actor = self.actor_future.actor
if not isinstance(actor, BaseActor):
@ -157,7 +160,10 @@ class CriticFactoryReuseActor(CriticFactory):
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
# TODO get rid of this prod pattern here and elsewhere
last_size = (
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
)
return discrete.Critic(
actor.get_preprocess_net(),
device=device,

View File

@ -45,12 +45,12 @@ class Persistence(ABC):
pass
@abstractmethod
def restore(self, event: RestoreEvent, world: World):
def restore(self, event: RestoreEvent, world: World) -> None:
pass
class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence, enabled=True):
def __init__(self, *p: Persistence, enabled: bool = True):
self.items = p
self.enabled = enabled
@ -60,7 +60,7 @@ class PersistenceGroup(Persistence):
for item in self.items:
item.persist(event, world)
def restore(self, event: RestoreEvent, world: World):
def restore(self, event: RestoreEvent, world: World) -> None:
for item in self.items:
item.restore(event, world)
@ -68,7 +68,7 @@ class PersistenceGroup(Persistence):
class PolicyPersistence:
FILENAME = "policy.dat"
def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
""":param additional_persistence: a persistence instance which is to be envoked whenever
this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled)
@ -93,7 +93,7 @@ class PolicyPersistence:
if self.additional_persistence is not None:
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
def save_best_fn(pol: torch.nn.Module) -> None:
self.persist(pol, world)

View File

@ -18,7 +18,7 @@ class World:
test_collector: "Collector"
logger: "TLogger"
persist_directory: str
restore_directory: str
restore_directory: str | None
trainer: Optional["BaseTrainer"] = None
def persist_path(self, filename: str) -> str:

View File

@ -17,7 +17,8 @@ from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils import RunningMeanStd
TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution]
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
class PGPolicy(BasePolicy):

View File

@ -216,7 +216,7 @@ class ActorProb(BaseActor):
def get_preprocess_net(self) -> nn.Module:
return self.preprocess
def get_output_dim(self):
def get_output_dim(self) -> int:
return self.output_dim
def forward(

View File

@ -189,7 +189,7 @@ class ImplicitQuantileNetwork(Critic):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
action_shape: Sequence[int] | int,
hidden_sizes: Sequence[int] = (),
num_cosines: int = 64,
preprocess_net_output_dim: int | None = None,