Reify policy persistence, introducing World representation
parent ee3813b09c
commit f6d49774a2
@@ -1,15 +1,12 @@
+import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from os import PathLike
 from typing import Any, Generic, TypeVar, cast

 import gymnasium
-import torch

 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
@@ -41,7 +38,9 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
+from tianshou.highlevel.persistence import PolicyPersistence
 from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
+from tianshou.highlevel.world import World
 from tianshou.policy import (
     A2CPolicy,
     BasePolicy,
@@ -71,6 +70,7 @@ TDiscreteCriticOnlyParams = TypeVar(
     bound=ParamsMixinLearningRateWithScheduler,
 )
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
+log = logging.getLogger(__name__)


 class AgentFactory(ABC, ToStringMixin):
@@ -133,58 +133,20 @@ class AgentFactory(ABC, ToStringMixin):
         )
         return policy

-    @staticmethod
-    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
-        def save_best_fn(pol: torch.nn.Module) -> None:
-            pass
-            # TODO: Fix saving in general (code works only for mujoco)
-            # state = {
-            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
-            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
-            # }
-            # torch.save(state, os.path.join(log_path, "policy.pth"))
-
-        return save_best_fn
-
-    @staticmethod
-    def load_checkpoint(
-        policy: torch.nn.Module,
-        path: str | PathLike,
-        envs: Environments,
-        device: TDevice,
-    ) -> None:
-        ckpt = torch.load(path, map_location=device)
-        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
-        if envs.train_envs:
-            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        if envs.test_envs:
-            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        print("Loaded agent and obs. running means from: ", path)  # TODO logging
-
     @abstractmethod
-    def create_trainer(
-        self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
-    ) -> BaseTrainer:
+    def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer:
         pass


 class OnpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OnpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -199,17 +161,17 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OnpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -220,15 +182,12 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
 class OffpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OffpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -243,16 +202,16 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OffpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,
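Taken together, the hunks above replace the long per-argument create_trainer signature with two reified objects. A minimal sketch of the resulting call pattern (illustrative only, not part of the diff; the `factory` instance and the objects passed to `World` are assumed to have been built as in the experiment code further below):

    world = World(
        envs=envs,
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        logger=logger,
    )
    policy_persistence = PolicyPersistence()
    trainer = factory.create_trainer(world, policy_persistence)  # OnpolicyTrainer or OffpolicyTrainer
    result = trainer.run()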
@@ -39,6 +39,7 @@ from tianshou.highlevel.module.actor import (
 from tianshou.highlevel.module.core import (
     ImplicitQuantileNetworkFactory,
     IntermediateModuleFactory,
+    TDevice,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -63,15 +64,17 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
     TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
+from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin

 log = logging.getLogger(__name__)
@@ -86,14 +89,18 @@ class ExperimentConfig:
     seed: int = 42
     render: float | None = 0.0
     """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
+    device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
+    """The torch device to use"""
+    policy_restore_directory: str | None = None
+    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    train: bool = True
+    """Whether to perform training"""
+    watch: bool = True
+    """Whether to watch agent performance (after training)"""
     watch_num_episodes = 10
+    """Number of episodes for which to watch performance (if watch is enabled)"""
+    watch_render: float = 0.0
+    """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""


 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
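The old resume_id/resume_path/watch trio is replaced by a more explicit set of fields. A hedged configuration sketch, assuming ExperimentConfig is a dataclass that accepts these fields as keyword arguments (the values shown are illustrative):

    config = ExperimentConfig(
        seed=42,
        train=True,                     # run the trainer
        watch=True,                     # afterwards, watch the agent for watch_num_episodes episodes
        watch_render=0.0,               # milliseconds between rendered frames while watching
        policy_restore_directory=None,  # set to a previous run's log directory to reload saved parameters
    )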
@@ -123,56 +130,72 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
             # TODO
         }

-    def run(self, log_name: str) -> None:
+    def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
+        """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
+            directory) where all results associated with the experiment will be saved.
+            The name may contain path separators (os.path.sep, used by os.path.join), in which case
+            a nested directory structure will be created.
+            If None, use a name containing the current date and time.
+        :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
+            using wandb, in particular).
+        :return:
+        """
+        if experiment_name is None:
+            experiment_name = datetime_tag()
+
         self._set_seed()
         envs = self.env_factory(self.env_config)
+        policy_persistence = PolicyPersistence()
         log.info(f"Created {envs}")

         full_config = self._build_config_dict()
         full_config.update(envs.info())

-        run_id = self.config.resume_id
         logger = self.logger_factory.create_logger(
-            log_name=log_name,
-            run_id=run_id,
+            log_name=experiment_name,
+            run_id=logger_run_id,
             config_dict=full_config,
         )

         policy = self.agent_factory.create_policy(envs, self.config.device)
-        if self.config.resume_path:
-            self.agent_factory.load_checkpoint(
-                policy,
-                self.config.resume_path,
-                envs,
-                self.config.device,
-            )
+
         train_collector, test_collector = self.agent_factory.create_train_test_collector(
             policy,
             envs,
         )

-        if not self.config.watch:
-            trainer = self.agent_factory.create_trainer(
-                policy,
-                train_collector,
-                test_collector,
-                envs,
-                logger,
+        world = World(
+            envs=envs,
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            logger=logger,
         )

+        if self.config.policy_restore_directory:
+            policy_persistence.restore(
+                policy,
+                self.config.policy_restore_directory,
+                self.config.device,
+            )
+
+        if self.config.train:
+            trainer = self.agent_factory.create_trainer(world, policy_persistence)
+            world.trainer = trainer
+
             result = trainer.run()
             pprint(result)  # TODO logging

-        render = self.config.render
-        if render is None:
-            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
+        if self.config.watch:
             self._watch_agent(
                 self.config.watch_num_episodes,
                 policy,
                 test_collector,
-                render,
+                self.config.watch_render,
             )

+        # TODO return result
+
     @staticmethod
     def _watch_agent(
         num_episodes: int,
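With the new run signature, the experiment name becomes optional. Illustrative calls (the `experiment` instance is hypothetical, assumed to have been fully configured beforehand):

    experiment.run(experiment_name="ppo_experiment", logger_run_id=None)
    experiment.run()  # defaults to a date/time-tagged directory name via datetime_tag()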
@@ -120,7 +120,8 @@ class ActorFactoryDefault(ActorFactory):
             return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
             factory = ActorFactoryDiscreteNet(
-                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
+                self.DEFAULT_HIDDEN_SIZES,
+                softmax_output=self.discrete_softmax,
             )
             return factory.create_module(envs, device)
         else:
@@ -1,5 +1,16 @@
+import logging
 import os
-from typing import Protocol, Self, runtime_checkable
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
+
+import torch
+
+from tianshou.highlevel.world import World
+
+if TYPE_CHECKING:
+    from tianshou.highlevel.module.core import TDevice
+
+log = logging.getLogger(__name__)


 @runtime_checkable
@@ -10,3 +21,50 @@ class PersistableConfigProtocol(Protocol):

     def save(self, path: os.PathLike[str]) -> None:
         pass
+
+
+class Persistence(ABC):
+    def path(self, world: World, filename: str) -> str:
+        return os.path.join(world.directory, filename)
+
+    @abstractmethod
+    def persist(self, world: World) -> None:
+        pass
+
+    @abstractmethod
+    def restore(self, world: World):
+        pass
+
+
+class PersistenceGroup(Persistence):
+    def __init__(self, *p: Persistence):
+        self.items = p
+
+    def persist(self, world: World) -> None:
+        for item in self.items:
+            item.persist(world)
+
+    def restore(self, world: World):
+        for item in self.items:
+            item.restore(world)
+
+
+class PolicyPersistence:
+    FILENAME = "policy.dat"
+
+    def persist(self, policy: torch.nn.Module, directory: str) -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Saving policy in {path}")
+        torch.save(policy.state_dict(), path)
+
+    def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Restoring policy from {path}")
+        state_dict = torch.load(path, map_location=device)
+        policy.load_state_dict(state_dict)
+
+    def get_save_best_fn(self, world):
+        def save_best_fn(pol: torch.nn.Module) -> None:
+            self.persist(pol, world.directory)
+
+        return save_best_fn
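PolicyPersistence stores only the policy's state dict under a fixed file name. A usage sketch (illustrative, not part of the diff; `policy` is a torch.nn.Module and `log_dir` a hypothetical run directory):

    persistence = PolicyPersistence()
    persistence.persist(policy, log_dir)                # writes <log_dir>/policy.dat
    persistence.restore(policy, log_dir, device="cpu")  # loads the state dict back in place
    save_best_fn = persistence.get_save_best_fn(world)  # the callback handed to the trainers above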
tianshou/highlevel/world.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from tianshou.data import Collector
+    from tianshou.highlevel.env import Environments
+    from tianshou.highlevel.logger import Logger
+    from tianshou.policy import BasePolicy
+    from tianshou.trainer import BaseTrainer
+
+
+@dataclass
+class World:
+    envs: "Environments"
+    policy: "BasePolicy"
+    train_collector: "Collector"
+    test_collector: "Collector"
+    logger: "Logger"
+    trainer: Optional["BaseTrainer"] = None
+
+    @property
+    def directory(self):
+        return self.logger.log_path
+
+    def path(self, filename: str) -> str:
+        return os.path.join(self.directory, filename)
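World itself carries no behaviour beyond bundling the run's objects and resolving paths inside the logger's log directory. Illustrative use, assuming `world` was constructed as in Experiment.run above (the filename below is only an example):

    assert world.directory == world.logger.log_path
    checkpoint_path = world.path("policy.dat")  # a file inside the run's log directory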