Reify policy persistence, introducing World representation
parent ee3813b09c
commit f6d49774a2
@@ -1,15 +1,12 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from os import PathLike
from typing import Any, Generic, TypeVar, cast

import gymnasium
import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import (
    ActorFactory,
)
@@ -41,7 +38,9 @@ from tianshou.highlevel.params.policy_params import (
    TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PolicyPersistence
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
from tianshou.highlevel.world import World
from tianshou.policy import (
    A2CPolicy,
    BasePolicy,
@@ -71,6 +70,7 @@ TDiscreteCriticOnlyParams = TypeVar(
    bound=ParamsMixinLearningRateWithScheduler,
)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)


class AgentFactory(ABC, ToStringMixin):
@@ -133,58 +133,20 @@ class AgentFactory(ABC, ToStringMixin):
        )
        return policy

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module) -> None:
            pass
            # TODO: Fix saving in general (code works only for mujoco)
            # state = {
            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            # }
            # torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(
        policy: torch.nn.Module,
        path: str | PathLike,
        envs: Environments,
        device: TDevice,
    ) -> None:
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
    def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
        world: World,
        policy_persistence: PolicyPersistence,
    ) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)
        context = TrainingContext(world.policy, world.envs, world.logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
@@ -199,17 +161,17 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
        )
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            policy=world.policy,
            train_collector=world.train_collector,
            test_collector=world.test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            save_best_fn=policy_persistence.get_save_best_fn(world),
            logger=world.logger.logger,
            test_in_train=False,
            train_fn=train_fn,
            test_fn=test_fn,
@@ -220,15 +182,12 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
class OffpolicyAgentFactory(AgentFactory, ABC):
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
        world: World,
        policy_persistence: PolicyPersistence,
    ) -> OffpolicyTrainer:
        sampling_config = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)
        context = TrainingContext(world.policy, world.envs, world.logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
@@ -243,16 +202,16 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
        )
        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            policy=world.policy,
            train_collector=world.train_collector,
            test_collector=world.test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            step_per_collect=sampling_config.step_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            save_best_fn=policy_persistence.get_save_best_fn(world),
            logger=world.logger.logger,
            update_per_step=sampling_config.update_per_step,
            test_in_train=False,
            train_fn=train_fn,
@@ -39,6 +39,7 @@ from tianshou.highlevel.module.actor import (
from tianshou.highlevel.module.core import (
    ImplicitQuantileNetworkFactory,
    IntermediateModuleFactory,
    TDevice,
)
from tianshou.highlevel.module.critic import (
    CriticEnsembleFactory,
@@ -63,15 +64,17 @@ from tianshou.highlevel.params.policy_params import (
    TRPOParams,
)
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.highlevel.persistence import PersistableConfigProtocol, PolicyPersistence
from tianshou.highlevel.trainer import (
    TrainerCallbacks,
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainerStopCallback,
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
from tianshou.utils.logging import datetime_tag
from tianshou.utils.string import ToStringMixin

log = logging.getLogger(__name__)
@@ -86,14 +89,18 @@ class ExperimentConfig:
    seed: int = 42
    render: float | None = 0.0
    """Milliseconds between rendered frames; if None, no rendering"""
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    resume_id: str | None = None
    """For restoring a model and running means of env-specifics from a checkpoint"""
    resume_path: str | None = None
    """For restoring a model and running means of env-specifics from a checkpoint"""
    watch: bool = False
    """If True, will not perform training and only watch the restored policy"""
    device: TDevice = "cuda" if torch.cuda.is_available() else "cpu"
    """The torch device to use"""
    policy_restore_directory: str | None = None
    """Directory from which to load the policy neural network parameters (saved in a previous run)"""
    train: bool = True
    """Whether to perform training"""
    watch: bool = True
    """Whether to watch agent performance (after training)"""
    watch_num_episodes = 10
    """Number of episodes for which to watch performance (if watch is enabled)"""
    watch_render: float = 0.0
    """Milliseconds between rendered frames when watching agent performance (if watch is enabled)"""


class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
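Pulling the new fields together, here is a minimal, purely illustrative sketch of how the revised ExperimentConfig might be instantiated; it assumes the class accepts the fields shown above as keyword arguments (it reads like a dataclass) and is not taken from the commit itself:

from tianshou.highlevel.experiment import ExperimentConfig

config = ExperimentConfig(
    seed=42,
    device="cpu",                    # TDevice; defaults to "cuda" when available
    policy_restore_directory=None,   # point at a previous run's directory to reload policy weights
    train=True,                      # perform training
    watch=True,                      # watch the trained agent afterwards
    watch_render=0.0,                # milliseconds between rendered frames while watching
)
# watch_num_episodes is declared above without a type annotation, so in a dataclass it
# would be a plain class attribute rather than an __init__ field; set it directly:
config.watch_num_episodes = 3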
@@ -123,55 +130,71 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
            # TODO
        }

    def run(self, log_name: str) -> None:
    def run(self, experiment_name: str | None = None, logger_run_id: str | None = None) -> None:
        """:param experiment_name: the experiment name, which corresponds to the directory (within the logging
            directory) where all results associated with the experiment will be saved.
            The name may contain path separators (os.path.sep, used by os.path.join), in which case
            a nested directory structure will be created.
            If None, use a name containing the current date and time.
        :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
            using wandb, in particular).
        :return:
        """
        if experiment_name is None:
            experiment_name = datetime_tag()

        self._set_seed()
        envs = self.env_factory(self.env_config)
        policy_persistence = PolicyPersistence()
        log.info(f"Created {envs}")

        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(
            log_name=log_name,
            run_id=run_id,
            log_name=experiment_name,
            run_id=logger_run_id,
            config_dict=full_config,
        )

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(
                policy,
                self.config.resume_path,
                envs,
                self.config.device,
            )

        train_collector, test_collector = self.agent_factory.create_train_test_collector(
            policy,
            envs,
        )

        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(
        world = World(
            envs=envs,
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            logger=logger,
        )

        if self.config.policy_restore_directory:
            policy_persistence.restore(
                policy,
                train_collector,
                test_collector,
                envs,
                logger,
                self.config.policy_restore_directory,
                self.config.device,
            )

        if self.config.train:
            trainer = self.agent_factory.create_trainer(world, policy_persistence)
            world.trainer = trainer

            result = trainer.run()
            pprint(result)  # TODO logging

        render = self.config.render
        if render is None:
            render = 0.0  # TODO: Perhaps we should have a second render parameter for watch mode?
        self._watch_agent(
            self.config.watch_num_episodes,
            policy,
            test_collector,
            render,
        )
        if self.config.watch:
            self._watch_agent(
                self.config.watch_num_episodes,
                policy,
                test_collector,
                self.config.watch_render,
            )

        # TODO return result

    @staticmethod
    def _watch_agent(
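As a hedged usage sketch of the new run() signature (not part of the diff; `experiment` stands for an Experiment instance assembled elsewhere, and the names passed in are purely illustrative):

import os

# With no experiment_name, run() falls back to a date/time tag via datetime_tag().
experiment.run()

# Or choose a nested results directory explicitly and pin the logger run id
# (useful for wandb resumption).
experiment.run(
    experiment_name=os.path.join("ppo", "seed-42"),
    logger_run_id="example-run-id",
)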
@@ -120,7 +120,8 @@ class ActorFactoryDefault(ActorFactory):
            return factory.create_module(envs, device)
        elif env_type == EnvType.DISCRETE:
            factory = ActorFactoryDiscreteNet(
                self.DEFAULT_HIDDEN_SIZES, softmax_output=self.discrete_softmax,
                self.DEFAULT_HIDDEN_SIZES,
                softmax_output=self.discrete_softmax,
            )
            return factory.create_module(envs, device)
        else:
@@ -1,5 +1,16 @@
import logging
import os
from typing import Protocol, Self, runtime_checkable
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Self, runtime_checkable

import torch

from tianshou.highlevel.world import World

if TYPE_CHECKING:
    from tianshou.highlevel.module.core import TDevice

log = logging.getLogger(__name__)


@runtime_checkable
@@ -10,3 +21,50 @@ class PersistableConfigProtocol(Protocol):

    def save(self, path: os.PathLike[str]) -> None:
        pass


class Persistence(ABC):
    def path(self, world: World, filename: str) -> str:
        return os.path.join(world.directory, filename)

    @abstractmethod
    def persist(self, world: World) -> None:
        pass

    @abstractmethod
    def restore(self, world: World):
        pass


class PersistenceGroup(Persistence):
    def __init__(self, *p: Persistence):
        self.items = p

    def persist(self, world: World) -> None:
        for item in self.items:
            item.persist(world)

    def restore(self, world: World):
        for item in self.items:
            item.restore(world)


class PolicyPersistence:
    FILENAME = "policy.dat"

    def persist(self, policy: torch.nn.Module, directory: str) -> None:
        path = os.path.join(directory, self.FILENAME)
        log.info(f"Saving policy in {path}")
        torch.save(policy.state_dict(), path)

    def restore(self, policy: torch.nn.Module, directory: str, device: "TDevice") -> None:
        path = os.path.join(directory, self.FILENAME)
        log.info(f"Restoring policy from {path}")
        state_dict = torch.load(path, map_location=device)
        policy.load_state_dict(state_dict)

    def get_save_best_fn(self, world):
        def save_best_fn(pol: torch.nn.Module) -> None:
            self.persist(pol, world.directory)

        return save_best_fn
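To make the new PolicyPersistence concrete, here is a small self-contained round trip; the toy Linear module and the temporary directory are illustrative only:

import tempfile

import torch

from tianshou.highlevel.persistence import PolicyPersistence

policy = torch.nn.Linear(4, 2)    # stand-in for a real policy network
restored = torch.nn.Linear(4, 2)  # fresh module to restore into

persistence = PolicyPersistence()
with tempfile.TemporaryDirectory() as directory:
    persistence.persist(policy, directory)           # writes <directory>/policy.dat
    persistence.restore(restored, directory, "cpu")  # loads the state dict onto the CPU

# The restored parameters now match the persisted ones.
assert all(
    torch.equal(p, q)
    for p, q in zip(policy.state_dict().values(), restored.state_dict().values())
)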
tianshou/highlevel/world.py — new file (27 lines)
@@ -0,0 +1,27 @@
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from tianshou.data import Collector
    from tianshou.highlevel.env import Environments
    from tianshou.highlevel.logger import Logger
    from tianshou.policy import BasePolicy
    from tianshou.trainer import BaseTrainer


@dataclass
class World:
    envs: "Environments"
    policy: "BasePolicy"
    train_collector: "Collector"
    test_collector: "Collector"
    logger: "Logger"
    trainer: Optional["BaseTrainer"] = None

    @property
    def directory(self):
        return self.logger.log_path

    def path(self, filename: str) -> str:
        return os.path.join(self.directory, filename)
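The Persistence base class introduced above resolves file locations through this World object. A hypothetical subclass (not part of the commit) could, for example, persist the observation running-mean statistics that the removed load_checkpoint code handled, alongside the policy file:

import pickle

from tianshou.highlevel.persistence import Persistence
from tianshou.highlevel.world import World


class ObsRmsPersistence(Persistence):
    """Hypothetical sketch: store the training envs' running obs statistics in the run directory."""

    FILENAME = "obs_rms.pkl"

    def persist(self, world: World) -> None:
        path = self.path(world, self.FILENAME)  # resolves to <world.directory>/obs_rms.pkl
        with open(path, "wb") as f:
            pickle.dump(world.envs.train_envs.get_obs_rms(), f)

    def restore(self, world: World) -> None:
        with open(self.path(world, self.FILENAME), "rb") as f:
            obs_rms = pickle.load(f)
        world.envs.train_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)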