Reify policy persistence, introducing World representation

Dominik Jain 2023-10-11 19:31:26 +02:00
parent ee3813b09c
commit f6d49774a2
5 changed files with 166 additions and 98 deletions


@@ -1,15 +1,12 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from os import PathLike
from typing import Any, Generic, TypeVar, cast
import gymnasium
import torch
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import (
ActorFactory,
)
@@ -41,7 +38,9 @@ from tianshou.highlevel.params.policy_params import (
TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PolicyPersistence
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
from tianshou.highlevel.world import World
from tianshou.policy import (
A2CPolicy,
BasePolicy,
@@ -71,6 +70,7 @@ TDiscreteCriticOnlyParams = TypeVar(
bound=ParamsMixinLearningRateWithScheduler,
)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)
class AgentFactory(ABC, ToStringMixin):
@@ -133,58 +133,20 @@ class AgentFactory(ABC, ToStringMixin):
)
return policy
@staticmethod
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
def save_best_fn(pol: torch.nn.Module) -> None:
pass
# TODO: Fix saving in general (code works only for mujoco)
# state = {
# CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
# CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
# }
# torch.save(state, os.path.join(log_path, "policy.pth"))
return save_best_fn
@staticmethod
def load_checkpoint(
policy: torch.nn.Module,
path: str | PathLike,
envs: Environments,
device: TDevice,
) -> None:
ckpt = torch.load(path, map_location=device)
policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
if envs.train_envs:
envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
if envs.test_envs:
envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
print("Loaded agent and obs. running means from: ", path) # TODO logging
@abstractmethod
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
) -> BaseTrainer:
def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer:
pass
class OnpolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
world: World,
policy_persistence: PolicyPersistence,
) -> OnpolicyTrainer:
sampling_config = self.sampling_config
callbacks = self.trainer_callbacks
context = TrainingContext(policy, envs, logger)
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
@@ -199,17 +161,17 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
)
return OnpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
policy=world.policy,
train_collector=world.train_collector,
test_collector=world.test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=self._create_save_best_fn(envs, logger.log_path),
logger=logger.logger,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger.logger,
test_in_train=False,
train_fn=train_fn,
test_fn=test_fn,
@@ -220,15 +182,12 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
class OffpolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
world: World,
policy_persistence: PolicyPersistence,
) -> OffpolicyTrainer:
sampling_config = self.sampling_config
callbacks = self.trainer_callbacks
context = TrainingContext(policy, envs, logger)
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
@@ -243,16 +202,16 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
)
return OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
policy=world.policy,
train_collector=world.train_collector,
test_collector=world.test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
save_best_fn=self._create_save_best_fn(envs, logger.log_path),
logger=logger.logger,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger.logger,
update_per_step=sampling_config.update_per_step,
test_in_train=False,
train_fn=train_fn,

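To make the refactoring concrete, here is a minimal usage sketch (not part of the commit) of the new contract: the former policy/collector/environment/logger parameters of create_trainer are bundled into a World, which is paired with a PolicyPersistence that supplies the trainer's save_best_fn. The helper function and its arguments below are placeholders; in the actual code path, Experiment.run() (shown further down in this commit) performs this wiring.

from tianshou.highlevel.persistence import PolicyPersistence
from tianshou.highlevel.world import World


def run_training(agent_factory, envs, policy, train_collector, test_collector, logger):
    # Placeholder wiring: the arguments are assumed to come from the usual
    # factories (AgentFactory.create_policy, create_train_test_collector, ...).
    world = World(
        envs=envs,
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        logger=logger,
    )
    policy_persistence = PolicyPersistence()
    # The factory now reads everything it needs from the World and installs
    # policy_persistence.get_save_best_fn(world) as the trainer's save_best_fn.
    trainer = agent_factory.create_trainer(world, policy_persistence)
    world.trainer = trainer
    return trainer.run()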

@@ -39,6 +39,7 @@ from tianshou.highlevel.module.actor import (
from tianshou.highlevel.module.core import (
ImplicitQuantileNetworkFactory,
IntermediateModuleFactory,
TDevice,
)
from tianshou.highlevel.module.critic import (
CriticEnsembleFactory,
@@ -63,15 +64,17 @@ from tianshou.highlevel.params.policy_params import (
TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
from tianshou.highlevel.trainer import (
TrainerCallbacks,
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainerStopCallback,
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
from tianshou.utils.logging import datetime_tag
from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__)
@@ -86,14 +89,18 @@ class ExperimentConfig:
seed: int = 42
render: float | None = 0.0
"""Milliseconds between rendered frames; if None, no rendering"""
device: str = "cuda" if torch.cuda.is_available() else "cpu"
resume_id: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
resume_path: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint"""
watch: bool = False
"""If True, will not perform training and only watch the restored policy"""
device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
"""The torch device to use"""
policy_restore_directory: str | None = None
"""Directory from which to load the policy neural network parameters (saved in a previous run)"""
train: bool = True
"""Whether to perform training"""
watch: bool = True
"""Whether to watch agent performance (after training)"""
watch_num_episodes: int = 10
"""Number of episodes for which to watch performance (if watch is enabled)"""
watch_render: float = 0.0
"""Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""
class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
@@ -123,55 +130,71 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
# TODO
}
def run(self, log_name: str) -> None:
def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
""":param experiment_name: the experiment name, which corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
The name may contain path separators (os.path.sep, used by os.path.join), in which case
a nested directory structure will be created.
If None, use a name containing the current date and time.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular).
:return:
"""
if experiment_name is None:
experiment_name = datetime_tag()
self._set_seed()
envs = self.env_factory(self.env_config)
policy_persistence = PolicyPersistence()
log.info(f"Created {envs}")
full_config = self._build_config_dict()
full_config.update(envs.info())
run_id = self.config.resume_id
logger = self.logger_factory.create_logger(
log_name=log_name,
run_id=run_id,
log_name=experiment_name,
run_id=logger_run_id,
config_dict=full_config,
)
policy = self.agent_factory.create_policy(envs, self.config.device)
if self.config.resume_path:
self.agent_factory.load_checkpoint(
policy,
self.config.resume_path,
envs,
self.config.device,
)
train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
)
if not self.config.watch:
trainer = self.agent_factory.create_trainer(
world = World(
envs=envs,
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
logger=logger,
)
if self.config.policy_restore_directory:
policy_persistence.restore(
policy,
train_collector,
test_collector,
envs,
logger,
self.config.policy_restore_directory,
self.config.device,
)
if self.config.train:
trainer = self.agent_factory.create_trainer(world, policy_persistence)
world.trainer = trainer
result = trainer.run()
pprint(result) # TODO logging
render = self.config.render
if render is None:
render = 0.0 # TODO: Perhaps we should have a second render parameter for watch mode?
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
render,
)
if self.config.watch:
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.watch_render,
)
# TODO return result
@staticmethod
def _watch_agent(

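As a usage sketch (not from the commit), the reworked config fields and the new run() signature could be exercised as follows. ExperimentConfig is assumed to be a dataclass importable from tianshou.highlevel.experiment (the module shown above), and make_experiment is a hypothetical stand-in for however the Experiment instance is assembled.

from tianshou.highlevel.experiment import ExperimentConfig  # assumed import path
from tianshou.utils.logging import datetime_tag

experiment_config = ExperimentConfig(
    device="cpu",
    policy_restore_directory="log/previous-run",  # load policy parameters saved by an earlier run
    train=False,                                  # skip training entirely ...
    watch=True,                                   # ... and only watch the restored policy
    watch_num_episodes=3,
    watch_render=0.0,
)
experiment = make_experiment(experiment_config)  # hypothetical assembly step

# Results are written beneath <log dir>/dqn/CartPole-v1/<timestamp>; passing
# experiment_name=None would instead fall back to a plain datetime_tag() name.
experiment.run(
    experiment_name=f"dqn/CartPole-v1/{datetime_tag()}",
    logger_run_id=None,  # e.g. a wandb run id when resuming that run's logger
)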

@@ -120,7 +120,8 @@ class ActorFactoryDefault(ActorFactory):
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(
self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
self.DEFAULT_HIDDEN_SIZES,
softmax_output=self.discrete_softmax,
)
return factory.create_module(envs, device)
else:


@@ -1,5 +1,16 @@
import logging
import os
from typing import Protocol, Self, runtime_checkable
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable
import torch
from tianshou.highlevel.world import World
if TYPE_CHECKING:
from tianshou.highlevel.module.core import TDevice
log = logging.getLogger(__name__)
@runtime_checkable
@@ -10,3 +21,50 @@ class PersistableConfigProtocol(Protocol):
def save(self, path: os.PathLike[str]) -> None:
pass
class Persistence(ABC):
def path(self, world: World, filename: str) -> str:
return os.path.join(world.directory, filename)
@abstractmethod
def persist(self, world: World) -> None:
pass
@abstractmethod
def restore(self, world: World):
pass
class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence):
self.items = p
def persist(self, world: World) -> None:
for item in self.items:
item.persist(world)
def restore(self, world: World):
for item in self.items:
item.restore(world)
class PolicyPersistence:
FILENAME = "policy.dat"
def persist(self, policy: torch.nn.Module, directory: str) -> None:
path = os.path.join(directory, self.FILENAME)
log.info(f"Saving policy in {path}")
torch.save(policy.state_dict(), path)
def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
path = os.path.join(directory, self.FILENAME)
log.info(f"Restoring policy from {path}")
state_dict = torch.load(path, map_location=device)
policy.load_state_dict(state_dict)
def get_save_best_fn(self, world):
def save_best_fn(pol: torch.nn.Module) -> None:
self.persist(pol, world.directory)
return save_best_fn

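A self-contained sketch (not part of the commit) exercising PolicyPersistence directly, with a plain torch module standing in for a policy:

import os
import tempfile

import torch

from tianshou.highlevel.persistence import PolicyPersistence

net = torch.nn.Linear(4, 2)  # stand-in for a BasePolicy (any torch.nn.Module works)
persistence = PolicyPersistence()

with tempfile.TemporaryDirectory() as directory:
    persistence.persist(net, directory)              # writes <directory>/policy.dat
    assert os.path.exists(os.path.join(directory, PolicyPersistence.FILENAME))
    restored = torch.nn.Linear(4, 2)
    persistence.restore(restored, directory, "cpu")  # loads the state dict onto the CPU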

@@ -0,0 +1,27 @@
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tianshou.data import Collector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
@dataclass
class World:
envs: "Environments"
policy: "BasePolicy"
train_collector: "Collector"
test_collector: "Collector"
logger: "Logger"
trainer: Optional["BaseTrainer"] = None
@property
def directory(self):
return self.logger.log_path
def path(self, filename: str) -> str:
return os.path.join(self.directory, filename)
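
Since World is a plain dataclass whose annotations are only resolved for type checking, its path helpers can be illustrated with stand-in members (a sketch, not part of the commit; in practice Experiment.run() fills every field with real tianshou objects):

from dataclasses import dataclass

from tianshou.highlevel.world import World


@dataclass
class _StubLogger:
    log_path: str  # the only attribute World.directory relies on


world = World(
    envs=None,              # Environments in real use
    policy=None,            # BasePolicy in real use
    train_collector=None,   # Collector in real use
    test_collector=None,    # Collector in real use
    logger=_StubLogger(log_path="log/experiment"),
)
print(world.directory)           # log/experiment
print(world.path("policy.dat"))  # log/experiment/policy.dat (platform-dependent separator)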