Reify policy persistence, introducing World representation

Dominik Jain 2023-10-11 19:31:26 +02:00
parent ee3813b09c
commit f6d49774a2
5 changed files with 166 additions and 98 deletions

View File

@@ -1,15 +1,12 @@
+import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable
-from os import PathLike
 from typing import Any, Generic, TypeVar, cast

 import gymnasium
-import torch

 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
@@ -41,7 +38,9 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
+from tianshou.highlevel.persistence import PolicyPersistence
 from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
+from tianshou.highlevel.world import World
 from tianshou.policy import (
     A2CPolicy,
     BasePolicy,
@@ -71,6 +70,7 @@ TDiscreteCriticOnlyParams = TypeVar(
     bound=ParamsMixinLearningRateWithScheduler,
 )
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
+log = logging.getLogger(__name__)


 class AgentFactory(ABC, ToStringMixin):
@@ -133,58 +133,20 @@ class AgentFactory(ABC, ToStringMixin):
             )
         return policy

-    @staticmethod
-    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
-        def save_best_fn(pol: torch.nn.Module) -> None:
-            pass
-            # TODO: Fix saving in general (code works only for mujoco)
-            # state = {
-            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
-            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
-            # }
-            # torch.save(state, os.path.join(log_path, "policy.pth"))
-
-        return save_best_fn
-
-    @staticmethod
-    def load_checkpoint(
-        policy: torch.nn.Module,
-        path: str | PathLike,
-        envs: Environments,
-        device: TDevice,
-    ) -> None:
-        ckpt = torch.load(path, map_location=device)
-        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
-        if envs.train_envs:
-            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        if envs.test_envs:
-            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
-        print("Loaded agent and obs. running means from: ", path)  # TODO logging
-
     @abstractmethod
-    def create_trainer(
-        self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
-    ) -> BaseTrainer:
+    def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer:
         pass


 class OnpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OnpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -199,17 +161,17 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OnpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             test_in_train=False,
             train_fn=train_fn,
             test_fn=test_fn,
@@ -220,15 +182,12 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
 class OffpolicyAgentFactory(AgentFactory, ABC):
     def create_trainer(
         self,
-        policy: BasePolicy,
-        train_collector: Collector,
-        test_collector: Collector,
-        envs: Environments,
-        logger: Logger,
+        world: World,
+        policy_persistence: PolicyPersistence,
     ) -> OffpolicyTrainer:
         sampling_config = self.sampling_config
         callbacks = self.trainer_callbacks
-        context = TrainingContext(policy, envs, logger)
+        context = TrainingContext(world.policy, world.envs, world.logger)
         train_fn = (
             callbacks.epoch_callback_train.get_trainer_fn(context)
             if callbacks.epoch_callback_train
@@ -243,16 +202,16 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
             callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
         )
         return OffpolicyTrainer(
-            policy=policy,
-            train_collector=train_collector,
-            test_collector=test_collector,
+            policy=world.policy,
+            train_collector=world.train_collector,
+            test_collector=world.test_collector,
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
             episode_per_test=sampling_config.num_test_envs,
             batch_size=sampling_config.batch_size,
-            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
-            logger=logger.logger,
+            save_best_fn=policy_persistence.get_save_best_fn(world),
+            logger=world.logger.logger,
             update_per_step=sampling_config.update_per_step,
             test_in_train=False,
             train_fn=train_fn,

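Note: the key change in this file is that the trainer's best-model hook no longer comes from AgentFactory._create_save_best_fn but from PolicyPersistence. A minimal sketch of what save_best_fn=policy_persistence.get_save_best_fn(world) hands to the trainer; run_directory is a placeholder standing in for world.directory (the logger's log path):

import torch

from tianshou.highlevel.persistence import PolicyPersistence

policy_persistence = PolicyPersistence()
run_directory = "/tmp/example-run"  # placeholder for world.directory

# Equivalent of the closure returned by policy_persistence.get_save_best_fn(world):
def save_best_fn(pol: torch.nn.Module) -> None:
    # persist() saves pol.state_dict() to <run_directory>/policy.dat via torch.save
    policy_persistence.persist(pol, run_directory)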
View File

@@ -39,6 +39,7 @@ from tianshou.highlevel.module.actor import (
 from tianshou.highlevel.module.core import (
     ImplicitQuantileNetworkFactory,
     IntermediateModuleFactory,
+    TDevice,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -63,15 +64,17 @@ from tianshou.highlevel.params.policy_params import (
     TRPOParams,
 )
 from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
-from tianshou.highlevel.persistence import PersistableConfigProtocol
+from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
 from tianshou.highlevel.trainer import (
     TrainerCallbacks,
     TrainerEpochCallbackTest,
     TrainerEpochCallbackTrain,
     TrainerStopCallback,
 )
+from tianshou.highlevel.world import World
 from tianshou.policy import BasePolicy
 from tianshou.trainer import BaseTrainer
+from tianshou.utils.logging import datetime_tag
 from tianshou.utils.string import ToStringMixin

 log = logging.getLogger(__name__)
@@ -86,14 +89,18 @@ class ExperimentConfig:
     seed: int = 42
     render: float | None = 0.0
     """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
+    device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
+    """The torch device to use"""
+    policy_restore_directory: str | None = None
+    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
+    train: bool = True
+    """Whether to perform training"""
+    watch: bool = True
+    """Whether to watch agent performance (after training)"""
     watch_num_episodes = 10
+    """Number of episodes for which to watch performance (if watch is enabled)"""
+    watch_render: float = 0.0
+    """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""


 class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@@ -123,55 +130,71 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
             # TODO
         }

-    def run(self, log_name: str) -> None:
+    def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
+        """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
+            directory) where all results associated with the experiment will be saved.
+            The name may contain path separators (os.path.sep, used by os.path.join), in which case
+            a nested directory structure will be created.
+            If None, use a name containing the current date and time.
+        :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
+            using wandb, in particular).
+        :return:
+        """
+        if experiment_name is None:
+            experiment_name = datetime_tag()
+
         self._set_seed()
         envs = self.env_factory(self.env_config)
+        policy_persistence = PolicyPersistence()
         log.info(f"Created {envs}")

         full_config = self._build_config_dict()
         full_config.update(envs.info())

-        run_id = self.config.resume_id
         logger = self.logger_factory.create_logger(
-            log_name=log_name,
-            run_id=run_id,
+            log_name=experiment_name,
+            run_id=logger_run_id,
             config_dict=full_config,
         )

         policy = self.agent_factory.create_policy(envs, self.config.device)
-        if self.config.resume_path:
-            self.agent_factory.load_checkpoint(
-                policy,
-                self.config.resume_path,
-                envs,
-                self.config.device,
-            )

         train_collector, test_collector = self.agent_factory.create_train_test_collector(
             policy,
             envs,
         )

-        if not self.config.watch:
-            trainer = self.agent_factory.create_trainer(
-                policy,
-                train_collector,
-                test_collector,
-                envs,
-                logger,
-            )
+        world = World(
+            envs=envs,
+            policy=policy,
+            train_collector=train_collector,
+            test_collector=test_collector,
+            logger=logger,
+        )
+
+        if self.config.policy_restore_directory:
+            policy_persistence.restore(
+                policy,
+                self.config.policy_restore_directory,
+                self.config.device,
+            )
+
+        if self.config.train:
+            trainer = self.agent_factory.create_trainer(world, policy_persistence)
+            world.trainer = trainer
             result = trainer.run()
             pprint(result)  # TODO logging

-        render = self.config.render
-        if render is None:
-            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
-        self._watch_agent(
-            self.config.watch_num_episodes,
-            policy,
-            test_collector,
-            render,
-        )  # TODO
+        if self.config.watch:
+            self._watch_agent(
+                self.config.watch_num_episodes,
+                policy,
+                test_collector,
+                self.config.watch_render,
+            )
+
+        return result

     @staticmethod
     def _watch_agent(

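Note: for callers, the switches that previously lived in resume_id/resume_path/watch move into the new config fields, and run() no longer requires a log name. A hedged usage sketch (assuming the file shown here is tianshou/highlevel/experiment.py, that ExperimentConfig accepts these keyword arguments with all other fields left at their defaults, and that an Experiment is built from the config as before; the run name and wandb id are placeholders):

from tianshou.highlevel.experiment import ExperimentConfig

config = ExperimentConfig(
    policy_restore_directory=None,  # or a previous run's log directory, to reload its policy.dat
    train=True,                     # run the trainer
    watch=True,                     # afterwards roll out the trained policy for watch_num_episodes episodes
    watch_render=0.0,               # milliseconds between rendered frames while watching
)

# experiment = ...  # built from this config as before
# experiment.run()  # experiment_name defaults to a datetime tag
# experiment.run("ppo/my-task", logger_run_id="<wandb-run-id>")  # nested log dirs; resume a wandb run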
View File

@@ -120,7 +120,8 @@ class ActorFactoryDefault(ActorFactory):
             return factory.create_module(envs, device)
         elif env_type == EnvType.DISCRETE:
             factory = ActorFactoryDiscreteNet(
-                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
+                self.DEFAULT_HIDDEN_SIZES,
+                softmax_output=self.discrete_softmax,
             )
             return factory.create_module(envs, device)
         else:

View File

@@ -1,5 +1,16 @@
+import logging
 import os
-from typing import Protocol, Self, runtime_checkable
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
+
+import torch
+
+from tianshou.highlevel.world import World
+
+if TYPE_CHECKING:
+    from tianshou.highlevel.module.core import TDevice
+
+log = logging.getLogger(__name__)


 @runtime_checkable
@@ -10,3 +21,50 @@ class PersistableConfigProtocol(Protocol):

     def save(self, path: os.PathLike[str]) -> None:
         pass
+
+
+class Persistence(ABC):
+    def path(self, world: World, filename: str) -> str:
+        return os.path.join(world.directory, filename)
+
+    @abstractmethod
+    def persist(self, world: World) -> None:
+        pass
+
+    @abstractmethod
+    def restore(self, world: World):
+        pass
+
+
+class PersistenceGroup(Persistence):
+    def __init__(self, *p: Persistence):
+        self.items = p
+
+    def persist(self, world: World) -> None:
+        for item in self.items:
+            item.persist(world)
+
+    def restore(self, world: World):
+        for item in self.items:
+            item.restore(world)
+
+
+class PolicyPersistence:
+    FILENAME = "policy.dat"
+
+    def persist(self, policy: torch.nn.Module, directory: str) -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Saving policy in {path}")
+        torch.save(policy.state_dict(), path)
+
+    def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
+        path = os.path.join(directory, self.FILENAME)
+        log.info(f"Restoring policy from {path}")
+        state_dict = torch.load(path, map_location=device)
+        policy.load_state_dict(state_dict)
+
+    def get_save_best_fn(self, world):
+        def save_best_fn(pol: torch.nn.Module) -> None:
+            self.persist(pol, world.directory)
+
+        return save_best_fn

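Note: PolicyPersistence only touches the policy's state_dict, so it can be exercised standalone. A minimal sketch (the run directory is a placeholder; persist() does not create it, hence the makedirs):

import os

import torch

from tianshou.highlevel.persistence import PolicyPersistence

net = torch.nn.Linear(4, 2)  # toy stand-in for a policy; only the state_dict interface matters
run_dir = "/tmp/example-run"  # placeholder run directory
os.makedirs(run_dir, exist_ok=True)

persistence = PolicyPersistence()
persistence.persist(net, run_dir)         # writes <run_dir>/policy.dat
persistence.restore(net, run_dir, "cpu")  # loads the state dict back onto the given device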
View File

@@ -0,0 +1,27 @@
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from tianshou.data import Collector
+    from tianshou.highlevel.env import Environments
+    from tianshou.highlevel.logger import Logger
+    from tianshou.policy import BasePolicy
+    from tianshou.trainer import BaseTrainer
+
+
+@dataclass
+class World:
+    envs: "Environments"
+    policy: "BasePolicy"
+    train_collector: "Collector"
+    test_collector: "Collector"
+    logger: "Logger"
+    trainer: Optional["BaseTrainer"] = None
+
+    @property
+    def directory(self):
+        return self.logger.log_path
+
+    def path(self, filename: str) -> str:
+        return os.path.join(self.directory, filename)
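Note: World is a plain aggregate of the objects Experiment.run has built so far; trainer is attached after construction because creating the trainer itself requires the World. A minimal sketch of the path helpers (SimpleNamespace and the directory are stand-ins; only logger.log_path is used here):

from types import SimpleNamespace

from tianshou.highlevel.world import World

logger = SimpleNamespace(log_path="/tmp/example-run")  # stand-in exposing only log_path
world = World(envs=None, policy=None, train_collector=None, test_collector=None, logger=logger)

print(world.directory)           # /tmp/example-run (the logger's log_path)
print(world.path("policy.dat"))  # /tmp/example-run/policy.dat
# Later: world.trainer = agent_factory.create_trainer(world, policy_persistence)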