format, type check and small fixes
parent f730782f29
commit d9a612a997
@@ -66,7 +66,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     builder = (
         DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     experiment = (
         IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -65,7 +65,13 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
 
-    env_factory = AtariEnvFactory(task, sampling_config.train_seed, sampling_config.test_seed, frames_stack, scale=scale_obs)
+    env_factory = AtariEnvFactory(
+        task,
+        sampling_config.train_seed,
+        sampling_config.test_seed,
+        frames_stack,
+        scale=scale_obs,
+    )
 
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -54,7 +54,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -51,7 +51,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -56,7 +56,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -61,7 +61,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -2,13 +2,12 @@
 
 import os
 from collections.abc import Sequence
-from functools import partial
 from typing import Literal
 
 import torch
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from examples.mujoco.tools import eval_results, RLiableExperimentResult
+from examples.mujoco.tools import RLiableExperimentResult, eval_results
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -19,7 +18,6 @@ from tianshou.highlevel.params.dist_fn import (
 )
 from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 from tianshou.highlevel.params.policy_params import PPOParams
-from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag
 
 
@@ -65,7 +63,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiments = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -102,13 +105,13 @@ def main(
 
 
 def eval_experiments(log_dir: str):
-    results = RLiableExperimentResult.load_from_disk(log_dir, 'PPO', None)
+    results = RLiableExperimentResult.load_from_disk(log_dir, "PPO", None)
     eval_results(results)
 
 
 if __name__ == "__main__":
     # logging.run_cli(main)
-    experiment_config = ExperimentConfig(watch=False)
-    log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
-    # log_dir = <path/to/exp>
+    # experiment_config = ExperimentConfig(watch=False)
+    # log_dir = logging.run_main(partial(main, experiment_config, epoch=2))
+    log_dir = "log/Ant-v4/ppo/20240312-114646"
     eval_experiments(log_dir)
@@ -57,7 +57,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -49,7 +49,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         PGExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -52,7 +52,12 @@ def main(
         start_timesteps_random=True,
    )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         SACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -58,7 +58,12 @@ def main(
         start_timesteps_random=True,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=False)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=False,
+    )
 
     experiment = (
         TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -58,7 +58,12 @@ def main(
         repeat_per_collect=repeat_per_collect,
     )
 
-    env_factory = MujocoEnvFactory(task, train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True)
+    env_factory = MujocoEnvFactory(
+        task,
+        train_seed=sampling_config.train_seed,
+        test_seed=sampling_config.test_seed,
+        obs_norm=True,
+    )
 
     experiment = (
         TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -5,7 +5,7 @@ import csv
 import os
 import re
 from collections import defaultdict
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 import numpy as np
 import tqdm
@@ -34,26 +34,33 @@ class RLiableExperimentResult:
         test_episode_returns = []
 
         for entry in os.scandir(exp_dir):
-            if entry.name.startswith('.'):
+            if entry.name.startswith(".") or not entry.is_dir():
                 continue
 
             exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(entry.path, entry.name, None, asdict(exp.config))
+            logger = exp.logger_factory.create_logger(
+                entry.path,
+                entry.name,
+                None,
+                asdict(exp.config),
+            )
             data = logger.restore_logged_data(entry.path)
 
-            test_data = data['test']
+            test_data = data["test"]
 
-            test_episode_returns.append(test_data['returns_stat']['mean'])
-            env_step = test_data['env_step']
+            test_episode_returns.append(test_data["returns_stat"]["mean"])
+            env_step = test_data["env_step"]
 
         if score_thresholds is None:
             score_thresholds = np.linspace(0.0, np.max(test_episode_returns), 101)
 
-        return RLiableExperimentResult(algorithms=[algo_name],
-                                       score_dict={algo_name: np.array(test_episode_returns)},
-                                       env_steps=np.array(env_step),
-                                       score_thresholds=score_thresholds,
-                                       exp_dir=exp_dir)
+        return RLiableExperimentResult(
+            algorithms=[algo_name],
+            score_dict={algo_name: np.array(test_episode_returns)},
+            env_steps=np.array(env_step),
+            score_thresholds=score_thresholds,
+            exp_dir=exp_dir,
+        )
 
 
 def eval_results(results: RLiableExperimentResult):
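Note on the new `is_dir()` guard above: experiment subdirectories can sit next to plain files (for instance the `log.csv`/`log.pkl` that PandasLogger writes, further down in this diff), and `Experiment.from_directory` expects a directory. A minimal standalone sketch of the same filtering, with a hypothetical path:

    import os

    exp_dir = "log/Ant-v4/ppo"  # hypothetical, for illustration only
    exp_subdirs = [
        e.path
        for e in os.scandir(exp_dir)
        if e.is_dir() and not e.name.startswith(".")  # same predicate as above
    ]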
@@ -64,31 +71,44 @@ def eval_results(results: RLiableExperimentResult):
     from rliable import plot_utils
 
     iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
-    iqm_scores, iqm_cis = rly.get_interval_estimates(
-        results.score_dict, iqm, reps=50000)
+    iqm_scores, iqm_cis = rly.get_interval_estimates(results.score_dict, iqm, reps=50000)
 
     # Plot IQM sample efficiency curve
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_sample_efficiency_curve(
-        results.env_steps, iqm_scores, iqm_cis, algorithms=results.algorithms,
-        xlabel=r'Number of env steps',
-        ylabel='IQM episode return',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'iqm_sample_efficiency_curve.png'))
+        results.env_steps,
+        iqm_scores,
+        iqm_cis,
+        algorithms=results.algorithms,
+        xlabel=r"Number of env steps",
+        ylabel="IQM episode return",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "iqm_sample_efficiency_curve.png"))
 
     final_score_dict = {algo: returns[:, [-1]] for algo, returns in results.score_dict.items()}
     score_distributions, score_distributions_cis = rly.create_performance_profile(
-        final_score_dict, results.score_thresholds)
+        final_score_dict,
+        results.score_thresholds,
+    )
 
     # Plot score distributions
     fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
     plot_utils.plot_performance_profiles(
-        score_distributions, results.score_thresholds,
+        score_distributions,
+        results.score_thresholds,
         performance_profile_cis=score_distributions_cis,
-        colors=dict(zip(results.algorithms, sns.color_palette('colorblind'))),
-        xlabel=r'Episode return $(\tau)$',
-        ax=ax)
-    plt.savefig(os.path.join(results.exp_dir, 'performance_profile.png'))
+        colors=dict(
+            zip(
+                results.algorithms,
+                sns.color_palette("colorblind", n_colors=len(results.algorithms)),
+                strict=True,
+            ),
+        ),
+        xlabel=r"Episode return $(\tau)$",
+        ax=ax,
+    )
+    plt.savefig(os.path.join(results.exp_dir, "performance_profile.png"))
 
 
 def find_all_files(root_dir, pattern):
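The `iqm` lambda above is the interquartile mean: scipy's `trim_mean` with a quarter trimmed from each end along the run axis. A self-contained sketch (the array shape is assumed for illustration, not taken from the diff):

    import numpy as np
    import scipy.stats as sst

    # scores: one row per training run, one column per evaluation step
    scores = np.random.default_rng(0).normal(loc=1.0, size=(10, 5))
    iqm_per_step = sst.trim_mean(scores, proportiontocut=0.25, axis=0)  # shape (5,)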
@@ -6,9 +6,19 @@ from tianshou.highlevel.env import (
 
 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="CartPole-v0", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="CartPole-v0",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )
 
 
 class ContinuousTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
-        super().__init__(task="Pendulum-v1", train_seed=42, test_seed=1337, venv_type=VectorEnvType.DUMMY)
+        super().__init__(
+            task="Pendulum-v1",
+            train_seed=42,
+            test_seed=1337,
+            venv_type=VectorEnvType.DUMMY,
+        )
@@ -3,7 +3,7 @@ import pickle
 from abc import abstractmethod
 from collections.abc import Sequence
 from copy import copy
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from pprint import pformat
 from typing import Self
 
@@ -87,9 +87,7 @@ from tianshou.utils.string import ToStringMixin
 log = logging.getLogger(__name__)
 
 
-def shortener(input_string: str | None = None,
-              length: int = 1
-              ):
+def shortener(input_string: str | None = None, length: int = 1) -> str:
     """Shorten the input string by keeping only the first `length` characters of each word.
 
     If the input string is None or empty, return "default".
@@ -367,19 +365,19 @@ class ExperimentBuilder:
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
 
     @property
-    def experiment_config(self):
+    def experiment_config(self) -> ExperimentConfig:
         return self._config
 
     @experiment_config.setter
-    def experiment_config(self, experiment_config: ExperimentConfig):
+    def experiment_config(self, experiment_config: ExperimentConfig) -> None:
         self._config = experiment_config
 
     @property
-    def sampling_config(self):
+    def sampling_config(self) -> SamplingConfig:
         return self._sampling_config
 
     @sampling_config.setter
-    def sampling_config(self, sampling_config: SamplingConfig):
+    def sampling_config(self, sampling_config: SamplingConfig) -> None:
         self._sampling_config = sampling_config
 
     def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
@@ -491,7 +489,6 @@ class ExperimentBuilder:
 
         The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
         """
-
         configured_experiment_config = copy(self.experiment_config)
         configured_experiment_seed = configured_experiment_config.seed
         configured_sampling_config = copy(self.sampling_config)
@@ -512,9 +509,13 @@ class ExperimentBuilder:
             self.sampling_config = SamplingConfig(**new_sampling_config_dict)
             exp = self.build()
 
-            full_name = ",".join([f"experiment_seed={exp.config.seed}",
-                                  f"train_seed={exp.sampling_config.train_seed}",
-                                  f"test_seed={exp.sampling_config.test_seed}"])
+            full_name = ",".join(
+                [
+                    f"experiment_seed={exp.config.seed}",
+                    f"train_seed={exp.sampling_config.train_seed}",
+                    f"test_seed={exp.sampling_config.test_seed}",
+                ],
+            )
             experiment_name = shortener(full_name, 4)
             seeded_experiments[experiment_name] = exp
 
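For illustration, the string that `shortener(full_name, 4)` compresses into the experiment key is built like this (seed values hypothetical):

    full_name = ",".join(
        [
            "experiment_seed=42",
            "train_seed=1042",
            "test_seed=1337",
        ],
    )
    # per shortener's docstring earlier in this diff, only the first 4
    # characters of each word are kept, yielding a compact per-seed name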
@@ -5,7 +5,6 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
-from tianshou.utils.logger.base import LoggerManager
 from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin
 
@@ -78,38 +77,3 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
-
-
-class LoggerManagerFactory(LoggerFactory):
-    def __init__(
-        self,
-        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
-        wandb_project: str | None = None,
-    ):
-        self.logger_types = logger_types
-        self.wandb_project = wandb_project
-
-        self.factories = {
-            "pandas": LoggerFactoryDefault(logger_type="pandas"),
-            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
-            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
-        }
-
-    def create_logger(
-        self,
-        log_dir: str,
-        experiment_name: str,
-        run_id: str | None,
-        config_dict: dict,
-    ) -> TLogger:
-        logger_manager = LoggerManager()
-        for logger_type in self.logger_types:
-            logger_manager.loggers.append(
-                self.factories[logger_type].create_logger(
-                    log_dir,
-                    experiment_name,
-                    run_id,
-                    config_dict,
-                )
-            )
-        return logger_manager
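With `LoggerManagerFactory` removed, `LoggerFactoryDefault` is the remaining entry point here. A usage sketch based only on the `create_logger` signature visible in the removed code (all argument values are hypothetical):

    factory = LoggerFactoryDefault(logger_type="pandas")
    logger = factory.create_logger(
        "log/some-experiment",  # log_dir, hypothetical
        "some-experiment",      # experiment_name, hypothetical
        None,                   # run_id
        {},                     # config_dict
    )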
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from enum import Enum
 from numbers import Number
-from typing import Any
 
 import numpy as np
 
@@ -60,7 +59,7 @@ class BaseLogger(ABC):
     """
 
     @staticmethod
-    def prepare_dict_for_logging(log_data: dict) -> dict:
+    def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
         return log_data
 
     def log_train_data(self, log_data: dict, step: int) -> None:
@@ -72,7 +71,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_train_step >= self.train_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data)
             self.last_log_train_step = step
 
     def log_test_data(self, log_data: dict, step: int) -> None:
@@ -84,7 +83,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
         if step - self.last_log_test_step >= self.test_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
+            self.write(f"{DataScope.TEST.value}/env_step", step, log_data)
             self.last_log_test_step = step
 
     def log_update_data(self, log_data: dict, step: int) -> None:
@@ -96,7 +95,7 @@ class BaseLogger(ABC):
         # TODO: move interval check to calling method
         if step - self.last_log_update_step >= self.update_interval:
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
+            self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data)
             self.last_log_update_step = step
 
     def log_info_data(self, log_data: dict, step: int) -> None:
@@ -109,7 +108,7 @@ class BaseLogger(ABC):
             step - self.last_log_info_step >= self.info_interval
         ):  # TODO: move interval check to calling method
             log_data = self.prepare_dict_for_logging(log_data)
-            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
+            self.write(f"{DataScope.INFO.value}/epoch", step, log_data)
             self.last_log_info_step = step
 
     @abstractmethod
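The four `write` calls above are behavior-preserving rewrites: the f-string builds the same tag string as the join. A one-line check, assuming `DataScope.TRAIN.value == "train"`:

    scope = "train"
    assert "/".join([scope, "env_step"]) == f"{scope}/env_step"  # "train/env_step"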
@@ -139,10 +138,12 @@ class BaseLogger(ABC):
         :return: epoch, env_step, gradient_step.
         """
 
-    @staticmethod
     @abstractmethod
-    def restore_logged_data(log_path):
-        """Load the logged data from dist for post-processing."""
+    def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Load the logged data from disk for post-processing.
+
+        :return: a dict containing the logged data.
+        """
 
 
 class LazyLogger(BaseLogger):
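Turning `restore_logged_data` into an instance method (with a concrete return type) matters because subclasses need per-instance state: `WandbLogger`, at the end of this diff, delegates to `self.tensorboard_logger`. A consumer sketch, using the keys that tools.py reads earlier in this diff:

    data = logger.restore_logged_data(log_path)  # logger: any BaseLogger subclass
    mean_returns = data["test"]["returns_stat"]["mean"]
    env_step = data["test"]["env_step"]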
@@ -166,57 +167,5 @@ class LazyLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0
 
-    def restore_logged_data(self):
-        return None
-
-
-class LoggerManager(BaseLogger):
-    """A container of loggers that holds more than one logger."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.loggers = []
-
-    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
-        for logger in self.loggers:
-            data_copy = data.copy()
-            logger.write(step_type, step, data_copy)
-
-    def log_train_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_train_data(log_data, step)
-
-    def log_test_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_test_data(log_data, step)
-
-    def log_update_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_update_data(log_data, step)
-
-    def log_info_data(self, log_data: dict, step: int) -> None:
-        for logger in self.loggers:
-            logger.log_info_data(log_data, step)
-
-    def save_data(
-        self,
-        epoch: int,
-        env_step: int,
-        gradient_step: int,
-        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
-    ) -> None:
-        for logger in self.loggers:
-            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
-
-    def restore_data(self) -> tuple[int, int, int]:
-        for logger in self.loggers:
-            epoch, env_step, gradient_step = logger.restore_data()
-
-        self.last_save_step = self.last_log_test_step = epoch
-        self.last_log_update_step = gradient_step
-        self.last_log_train_step = env_step
-
-        return epoch, env_step, gradient_step
-
-    def restore_logged_data(self, log_path):
-        return self.loggers[0].restore_logged_data(log_path)
+    def restore_logged_data(self, log_path: str) -> dict:
+        return {}
@@ -1,14 +1,12 @@
 import os
 from collections import defaultdict
-from functools import partial
-from typing import Callable, Any
+from collections.abc import Callable
+from typing import Any
 
-import numpy as np
 import pandas as pd
-from matplotlib.figure import Figure
 
 from tianshou.utils import BaseLogger, logging
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE
 
 
 class PandasLogger(BaseLogger):
@@ -21,20 +19,47 @@ class PandasLogger(BaseLogger):
         info_interval: int = 1,
         exclude_arrays: bool = True,
     ) -> None:
-        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        super().__init__(
+            train_interval,
+            test_interval,
+            update_interval,
+            info_interval,
+            exclude_arrays,
+        )
         self.log_path = log_dir
         self.csv_name = os.path.join(self.log_path, "log.csv")
         self.pkl_name = os.path.join(self.log_path, "log.pkl")
-        self.data = defaultdict(list)
+        self.data: dict[str, list] = defaultdict(list)
         self.last_save_step = -1
 
+    @staticmethod
+    def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Removes invalid data types from the log data."""
+        filtered_dict = data.copy()
+
+        def filter_dict(data_dict: dict[str, Any]) -> None:
+            """Filter in place."""
+            for key, value in data_dict.items():
+                if isinstance(value, VALID_LOG_VALS):
+                    filter_dict(value)
+                else:
+                    filtered_dict.pop(key)
+
+        filter_dict(data)
+        return filtered_dict
+
     def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
         scope, step_name = step_type.split("/")
         data[step_name] = step
         self.data[scope].append(data)
 
-    def save_data(self, epoch: int, env_step: int, gradient_step: int,
-                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
         self.last_save_step = epoch
         # create and dump a dataframe
         for k, v in self.data.items():
@@ -45,7 +70,13 @@ class PandasLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         for scope in ["train", "test", "update", "info"]:
             try:
-                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
+                self.data[scope].extend(
+                    list(
+                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
+                        .T.to_dict()
+                        .values(),
+                    ),
+                )
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
 
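The restructured call is the read half of a DataFrame round-trip: `.T.to_dict().values()` recovers the original list of row dicts from a pickled frame. A minimal sketch:

    import pandas as pd

    records = [{"env_step": 1, "loss": 0.5}, {"env_step": 2, "loss": 0.4}]
    df = pd.DataFrame(records)
    restored = list(df.T.to_dict().values())
    assert restored == records  # row dicts come back in order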
@@ -67,12 +98,11 @@ class PandasLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         data = {}
 
-        def merge_dicts(list_of_dicts):
-            result = defaultdict(list)
+        def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
+            result: dict[str, Any] = defaultdict(list)
             for d in list_of_dicts:
                 for key, value in d.items():
                     if isinstance(value, dict):
@@ -85,22 +115,10 @@ class PandasLogger(BaseLogger):
 
         for scope in ["train", "test", "update", "info"]:
             try:
-                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                dict_list = list(
+                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
+                )
                 data[scope] = merge_dicts(dict_list)
             except FileNotFoundError:
                 logging.warning(f"Failed to restore {scope} data")
         return data
-
-    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Filter out matplotlib figures from the data."""
-        filtered_dict = data.copy()
-
-        def filter_dict(d):
-            for key, value in d.items():
-                if isinstance(value, dict):
-                    filter_dict(value)
-                elif isinstance(value, Figure):
-                    filtered_dict.pop(key)
-
-        filter_dict(data)
-        return filtered_dict
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from collections.abc import Callable
 from typing import Any
 
@@ -7,7 +6,7 @@ from matplotlib.figure import Figure
 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
+from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger
 from tianshou.utils.warning import deprecation
 
 
@@ -86,7 +85,7 @@ class TensorboardLogger(BaseLogger):
         scope, step_name = step_type.split("/")
         self.writer.add_scalar(step_type, step, global_step=step)
         for k, v in data.items():
-            scope_key = '/'.join([scope, k])
+            scope_key = f"{scope}/{k}"
             if isinstance(v, np.ndarray):
                 self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
             elif isinstance(v, Figure):
@@ -133,20 +132,19 @@ class TensorboardLogger(BaseLogger):
 
         return epoch, env_step, gradient_step
 
-    @staticmethod
-    def restore_logged_data(log_path):
+    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
         ea = event_accumulator.EventAccumulator(log_path)
         ea.Reload()
 
-        def add_to_dict(dictionary, keys, value):
-            current_dict = dictionary
+        def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
+            current_dict = data_dict
             for key in keys[:-1]:
                 current_dict = current_dict.setdefault(key, {})
             current_dict[keys[-1]] = value
 
-        data = {}
+        data: dict[str, Any] = {}
         for key in ea.scalars.Keys():
-            split_keys = key.split('/')
+            split_keys = key.split("/")
             add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
 
         return data
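A quick illustration of `add_to_dict` (as defined above) on a slash-separated scalar tag, matching the nesting that tools.py consumes via `data["test"]["returns_stat"]["mean"]`:

    data: dict = {}
    add_to_dict(data, ["test", "returns_stat", "mean"], 1.0)
    # data == {"test": {"returns_stat": {"mean": 1.0}}}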
@@ -156,3 +156,7 @@ class WandbLogger(BaseLogger):
         except KeyError:
             env_step = 0
         return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path: str) -> dict:
+        assert self.tensorboard_logger is not None
+        return self.tensorboard_logger.restore_logged_data(log_path)