Make mypy happy
This commit is contained in:
parent
76e870207d
commit
023b33c917
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Generic, TypeVar, cast
|
from typing import Any, Generic, TypeVar, cast
|
||||||
|
|
||||||
@ -63,11 +64,17 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
TParams = TypeVar("TParams", bound=Params)
|
TParams = TypeVar("TParams", bound=Params)
|
||||||
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
|
TActorCriticParams = TypeVar(
|
||||||
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
|
"TActorCriticParams",
|
||||||
|
bound=Params | ParamsMixinLearningRateWithScheduler,
|
||||||
|
)
|
||||||
|
TActorDualCriticsParams = TypeVar(
|
||||||
|
"TActorDualCriticsParams",
|
||||||
|
bound=Params | ParamsMixinActorAndDualCritics,
|
||||||
|
)
|
||||||
TDiscreteCriticOnlyParams = TypeVar(
|
TDiscreteCriticOnlyParams = TypeVar(
|
||||||
"TDiscreteCriticOnlyParams",
|
"TDiscreteCriticOnlyParams",
|
||||||
bound=ParamsMixinLearningRateWithScheduler,
|
bound=Params | ParamsMixinLearningRateWithScheduler,
|
||||||
)
|
)
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -321,6 +328,7 @@ class ActorCriticAgentFactory(
|
|||||||
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
||||||
return ActorCriticModuleOpt(actor_critic, optim)
|
return ActorCriticModuleOpt(actor_critic, optim)
|
||||||
|
|
||||||
|
@typing.no_type_check
|
||||||
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
|
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
|
||||||
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
@ -382,6 +390,7 @@ class DiscreteCriticOnlyAgentFactory(
|
|||||||
def _get_policy_class(self) -> type[TPolicy]:
|
def _get_policy_class(self) -> type[TPolicy]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@typing.no_type_check
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
model = self.model_factory.create_module(envs, device)
|
model = self.model_factory.create_module(envs, device)
|
||||||
optim = self.optim_factory.create_optimizer(model, self.params.lr)
|
optim = self.optim_factory.create_optimizer(model, self.params.lr)
|
||||||
@ -548,6 +557,7 @@ class ActorDualCriticsAgentFactory(
|
|||||||
def _get_critic_use_action(envs: Environments) -> bool:
|
def _get_critic_use_action(envs: Environments) -> bool:
|
||||||
return envs.get_type().is_continuous()
|
return envs.get_type().is_continuous()
|
||||||
|
|
||||||
|
@typing.no_type_check
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
|
||||||
actor = self.actor_factory.create_module_opt(
|
actor = self.actor_factory.create_module_opt(
|
||||||
envs,
|
envs,
|
||||||
|
@ -37,7 +37,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
self.persistence = []
|
self.persistence: Sequence[Persistence] = []
|
||||||
|
|
||||||
def _tostring_includes(self) -> list[str]:
|
def _tostring_includes(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
@ -51,7 +51,7 @@ class Environments(ToStringMixin, ABC):
|
|||||||
"state_shape": self.get_observation_shape(),
|
"state_shape": self.get_observation_shape(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_persistence(self, *p: Persistence):
|
def set_persistence(self, *p: Persistence) -> None:
|
||||||
self.persistence = p
|
self.persistence = p
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory, Environments
|
from tianshou.highlevel.env import EnvFactory, Environments
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
ActorFactoryDefault,
|
ActorFactoryDefault,
|
||||||
@ -142,13 +142,18 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
self.env_config = env_config
|
self.env_config = env_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_directory(cls, directory: str) -> Self:
|
def from_directory(cls, directory: str, restore_policy: bool = True) -> Self:
|
||||||
"""Restores an experiment from a previously stored pickle.
|
"""Restores an experiment from a previously stored pickle.
|
||||||
|
|
||||||
:param directory: persistence directory of a previous run, in which a pickled experiment is found
|
:param directory: persistence directory of a previous run, in which a pickled experiment is found
|
||||||
|
:param restore_policy: whether the experiment shall be configured to restore the policy that was
|
||||||
|
persisted in the given directory
|
||||||
"""
|
"""
|
||||||
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
|
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
|
||||||
return pickle.load(f)
|
experiment: Experiment = pickle.load(f)
|
||||||
|
if restore_policy:
|
||||||
|
experiment.config.policy_restore_directory = directory
|
||||||
|
return experiment
|
||||||
|
|
||||||
def _set_seed(self) -> None:
|
def _set_seed(self) -> None:
|
||||||
seed = self.config.seed
|
seed = self.config.seed
|
||||||
@ -159,7 +164,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
def _build_config_dict(self) -> dict:
|
def _build_config_dict(self) -> dict:
|
||||||
return {"experiment": self.pprints()}
|
return {"experiment": self.pprints()}
|
||||||
|
|
||||||
def save(self, directory: str):
|
def save(self, directory: str) -> None:
|
||||||
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
|
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
|
||||||
log.info(
|
log.info(
|
||||||
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
|
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
|
||||||
@ -187,7 +192,8 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
os.makedirs(persistence_dir, exist_ok=True)
|
os.makedirs(persistence_dir, exist_ok=True)
|
||||||
|
|
||||||
with logging.FileLoggerContext(
|
with logging.FileLoggerContext(
|
||||||
os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
|
os.path.join(persistence_dir, self.LOG_FILENAME),
|
||||||
|
enabled=use_persistence,
|
||||||
):
|
):
|
||||||
# log initial information
|
# log initial information
|
||||||
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
|
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
|
||||||
@ -209,6 +215,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
# initialize logger
|
# initialize logger
|
||||||
full_config = self._build_config_dict()
|
full_config = self._build_config_dict()
|
||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
|
logger: TLogger
|
||||||
if use_persistence:
|
if use_persistence:
|
||||||
logger = self.logger_factory.create_logger(
|
logger = self.logger_factory.create_logger(
|
||||||
log_dir=persistence_dir,
|
log_dir=persistence_dir,
|
||||||
@ -460,7 +467,7 @@ class _BuilderMixinCriticsFactory:
|
|||||||
|
|
||||||
|
|
||||||
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
||||||
def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
|
def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
|
||||||
super().__init__(1, actor_future_provider)
|
super().__init__(1, actor_future_provider)
|
||||||
|
|
||||||
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
||||||
@ -553,7 +560,7 @@ class _BuilderMixinCriticEnsembleFactory:
|
|||||||
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
|
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _get_critic_ensemble_factory(self):
|
def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
|
||||||
if self.critic_ensemble_factory is None:
|
if self.critic_ensemble_factory is None:
|
||||||
return CriticEnsembleFactoryDefault()
|
return CriticEnsembleFactoryDefault()
|
||||||
else:
|
else:
|
||||||
@ -745,8 +752,10 @@ class IQNExperimentBuilder(ExperimentBuilder):
|
|||||||
):
|
):
|
||||||
super().__init__(env_factory, experiment_config, sampling_config)
|
super().__init__(env_factory, experiment_config, sampling_config)
|
||||||
self._params: IQNParams = IQNParams()
|
self._params: IQNParams = IQNParams()
|
||||||
self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
|
self._preprocess_network_factory: IntermediateModuleFactory = (
|
||||||
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
|
IntermediateModuleFactoryFromActorFactory(
|
||||||
|
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def with_iqn_params(self, params: IQNParams) -> Self:
|
def with_iqn_params(self, params: IQNParams) -> Self:
|
||||||
|
@ -4,16 +4,20 @@ from typing import Literal, TypeAlias
|
|||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
TLogger: TypeAlias = BaseLogger
|
||||||
|
|
||||||
|
|
||||||
class LoggerFactory(ToStringMixin, ABC):
|
class LoggerFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_logger(
|
def create_logger(
|
||||||
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
|
self,
|
||||||
|
log_dir: str,
|
||||||
|
experiment_name: str,
|
||||||
|
run_id: str | None,
|
||||||
|
config_dict: dict,
|
||||||
) -> TLogger:
|
) -> TLogger:
|
||||||
""":param log_dir: path to the directory in which log data is to be stored
|
""":param log_dir: path to the directory in which log data is to be stored
|
||||||
:param experiment_name: the name of the job, which may contain os.path.sep
|
:param experiment_name: the name of the job, which may contain os.path.sep
|
||||||
@ -35,7 +39,11 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
self.wandb_project = wandb_project
|
self.wandb_project = wandb_project
|
||||||
|
|
||||||
def create_logger(
|
def create_logger(
|
||||||
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
|
self,
|
||||||
|
log_dir: str,
|
||||||
|
experiment_name: str,
|
||||||
|
run_id: str | None,
|
||||||
|
config_dict: dict,
|
||||||
) -> TLogger:
|
) -> TLogger:
|
||||||
writer = SummaryWriter(log_dir)
|
writer = SummaryWriter(log_dir)
|
||||||
writer.add_text(
|
writer.add_text(
|
||||||
@ -48,18 +56,18 @@ class DefaultLoggerFactory(LoggerFactory):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger: TLogger
|
match self.logger_type:
|
||||||
if self.logger_type == "wandb":
|
case "wandb":
|
||||||
logger = WandbLogger(
|
wandb_logger = WandbLogger(
|
||||||
save_interval=1,
|
save_interval=1,
|
||||||
name=experiment_name.replace(os.path.sep, "__"),
|
name=experiment_name.replace(os.path.sep, "__"),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
config=config_dict,
|
config=config_dict,
|
||||||
project=self.wandb_project,
|
project=self.wandb_project,
|
||||||
)
|
)
|
||||||
logger.load(writer)
|
wandb_logger.load(writer)
|
||||||
elif self.logger_type == "tensorboard":
|
return wandb_logger
|
||||||
logger = TensorboardLogger(writer)
|
case "tensorboard":
|
||||||
else:
|
return TensorboardLogger(writer)
|
||||||
raise ValueError(f"Unknown logger type '{self.logger_type}'")
|
case _:
|
||||||
return logger
|
raise ValueError(f"Unknown logger type '{self.logger_type}'")
|
||||||
|
@ -209,16 +209,16 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
|
|||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self._actor_future = actor_future
|
self._actor_future = actor_future
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self) -> dict:
|
||||||
d = dict(self.__dict__)
|
d = dict(self.__dict__)
|
||||||
del d["_actor_future"]
|
del d["_actor_future"]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state: dict) -> None:
|
||||||
self.__dict__ = state
|
self.__dict__ = state
|
||||||
self._actor_future = ActorFuture()
|
self._actor_future = ActorFuture()
|
||||||
|
|
||||||
def _tostring_excludes(self):
|
def _tostring_excludes(self) -> list[str]:
|
||||||
return [*super()._tostring_excludes(), "_actor_future"]
|
return [*super()._tostring_excludes(), "_actor_future"]
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
@ -61,7 +62,7 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
envs: Environments,
|
envs: Environments,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
use_action: bool,
|
use_action: bool,
|
||||||
discrete_last_size_use_action_shape=False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
factory: CriticFactory
|
factory: CriticFactory
|
||||||
env_type = envs.get_type()
|
env_type = envs.get_type()
|
||||||
@ -89,7 +90,7 @@ class CriticFactoryContinuousNet(CriticFactory):
|
|||||||
envs: Environments,
|
envs: Environments,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
use_action: bool,
|
use_action: bool,
|
||||||
discrete_last_size_use_action_shape=False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
action_shape = envs.get_action_shape() if use_action else 0
|
action_shape = envs.get_action_shape() if use_action else 0
|
||||||
net_c = Net(
|
net_c = Net(
|
||||||
@ -114,7 +115,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
envs: Environments,
|
envs: Environments,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
use_action: bool,
|
use_action: bool,
|
||||||
discrete_last_size_use_action_shape=False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
action_shape = envs.get_action_shape() if use_action else 0
|
action_shape = envs.get_action_shape() if use_action else 0
|
||||||
net_c = Net(
|
net_c = Net(
|
||||||
@ -125,7 +126,9 @@ class CriticFactoryDiscreteNet(CriticFactory):
|
|||||||
activation=nn.Tanh,
|
activation=nn.Tanh,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
|
last_size = (
|
||||||
|
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
|
||||||
|
)
|
||||||
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
|
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
@ -149,7 +152,7 @@ class CriticFactoryReuseActor(CriticFactory):
|
|||||||
envs: Environments,
|
envs: Environments,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
use_action: bool,
|
use_action: bool,
|
||||||
discrete_last_size_use_action_shape=False,
|
discrete_last_size_use_action_shape: bool = False,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
actor = self.actor_future.actor
|
actor = self.actor_future.actor
|
||||||
if not isinstance(actor, BaseActor):
|
if not isinstance(actor, BaseActor):
|
||||||
@ -157,7 +160,10 @@ class CriticFactoryReuseActor(CriticFactory):
|
|||||||
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
|
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
|
||||||
)
|
)
|
||||||
if envs.get_type().is_discrete():
|
if envs.get_type().is_discrete():
|
||||||
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
|
# TODO get rid of this prod pattern here and elsewhere
|
||||||
|
last_size = (
|
||||||
|
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
|
||||||
|
)
|
||||||
return discrete.Critic(
|
return discrete.Critic(
|
||||||
actor.get_preprocess_net(),
|
actor.get_preprocess_net(),
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -45,12 +45,12 @@ class Persistence(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore(self, event: RestoreEvent, world: World):
|
def restore(self, event: RestoreEvent, world: World) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PersistenceGroup(Persistence):
|
class PersistenceGroup(Persistence):
|
||||||
def __init__(self, *p: Persistence, enabled=True):
|
def __init__(self, *p: Persistence, enabled: bool = True):
|
||||||
self.items = p
|
self.items = p
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class PersistenceGroup(Persistence):
|
|||||||
for item in self.items:
|
for item in self.items:
|
||||||
item.persist(event, world)
|
item.persist(event, world)
|
||||||
|
|
||||||
def restore(self, event: RestoreEvent, world: World):
|
def restore(self, event: RestoreEvent, world: World) -> None:
|
||||||
for item in self.items:
|
for item in self.items:
|
||||||
item.restore(event, world)
|
item.restore(event, world)
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class PersistenceGroup(Persistence):
|
|||||||
class PolicyPersistence:
|
class PolicyPersistence:
|
||||||
FILENAME = "policy.dat"
|
FILENAME = "policy.dat"
|
||||||
|
|
||||||
def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
|
def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
|
||||||
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
""":param additional_persistence: a persistence instance which is to be invoked whenever
|
||||||
this object is used to persist/restore data
|
this object is used to persist/restore data
|
||||||
:param enabled: whether persistence is enabled (restoration is always enabled)
|
:param enabled: whether persistence is enabled (restoration is always enabled)
|
||||||
@ -93,7 +93,7 @@ class PolicyPersistence:
|
|||||||
if self.additional_persistence is not None:
|
if self.additional_persistence is not None:
|
||||||
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
|
||||||
|
|
||||||
def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
|
def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
|
||||||
def save_best_fn(pol: torch.nn.Module) -> None:
|
def save_best_fn(pol: torch.nn.Module) -> None:
|
||||||
self.persist(pol, world)
|
self.persist(pol, world)
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class World:
|
|||||||
test_collector: "Collector"
|
test_collector: "Collector"
|
||||||
logger: "TLogger"
|
logger: "TLogger"
|
||||||
persist_directory: str
|
persist_directory: str
|
||||||
restore_directory: str
|
restore_directory: str | None
|
||||||
trainer: Optional["BaseTrainer"] = None
|
trainer: Optional["BaseTrainer"] = None
|
||||||
|
|
||||||
def persist_path(self, filename: str) -> str:
|
def persist_path(self, filename: str) -> str:
|
||||||
|
@ -17,7 +17,8 @@ from tianshou.policy import BasePolicy
|
|||||||
from tianshou.policy.base import TLearningRateScheduler
|
from tianshou.policy.base import TLearningRateScheduler
|
||||||
from tianshou.utils import RunningMeanStd
|
from tianshou.utils import RunningMeanStd
|
||||||
|
|
||||||
TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution]
|
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
|
||||||
|
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
|
||||||
|
|
||||||
|
|
||||||
class PGPolicy(BasePolicy):
|
class PGPolicy(BasePolicy):
|
||||||
|
@ -216,7 +216,7 @@ class ActorProb(BaseActor):
|
|||||||
def get_preprocess_net(self) -> nn.Module:
|
def get_preprocess_net(self) -> nn.Module:
|
||||||
return self.preprocess
|
return self.preprocess
|
||||||
|
|
||||||
def get_output_dim(self):
|
def get_output_dim(self) -> int:
|
||||||
return self.output_dim
|
return self.output_dim
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -189,7 +189,7 @@ class ImplicitQuantileNetwork(Critic):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
preprocess_net: nn.Module,
|
preprocess_net: nn.Module,
|
||||||
action_shape: Sequence[int],
|
action_shape: Sequence[int] | int,
|
||||||
hidden_sizes: Sequence[int] = (),
|
hidden_sizes: Sequence[int] = (),
|
||||||
num_cosines: int = 64,
|
num_cosines: int = 64,
|
||||||
preprocess_net_output_dim: int | None = None,
|
preprocess_net_output_dim: int | None = None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user