Closes #914

Additional changes:

- Deprecate Python versions below 3.11
- Remove third-party and throughput tests; this simplifies the install and test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible (a short illustration follows below)

Since the additional tests and gym were complicating the CI pipeline (flaky and dist-dependent), it didn't make sense to fix the current tests in this PR only to delete them in the next one. So this PR changes the build and removes these tests at the same time.
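For context, a minimal sketch of what the `strict` flag does (PEP 618; the variable names below are hypothetical): `strict=True` turns silent truncation on a length mismatch into an explicit `ValueError`.

```python
xs = [1, 2, 3]
ys = ["a", "b"]

# Default behavior: the unmatched element is silently dropped.
assert list(zip(xs, ys)) == [(1, "a"), (2, "b")]

# With strict=True, a length mismatch raises ValueError instead,
# surfacing bugs that silent truncation would hide.
try:
    list(zip(xs, ys, strict=True))
except ValueError:
    print("length mismatch detected")
```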
from abc import ABC, abstractmethod
from collections.abc import Callable
from numbers import Number

import numpy as np

LOG_DATA_TYPE = dict[str, int | Number | np.number | np.ndarray]


class BaseLogger(ABC):
    """The base class for any logger that is compatible with the trainer.

    Override the write() method to use your own writer.

    :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data(). Default to 1000.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
    ) -> None:
        super().__init__()
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1

    @abstractmethod
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """Specify how the writer is used to log data.

        :param str step_type: namespace to which the data dict belongs.
        :param int step: the ordinate of the data dict.
        :param dict data: the data to write with format ``{key: value}``.
        """

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during training.

        :param collect_result: a dict containing information about the data collected
            during the training stage, i.e., the return of collector.collect().
        :param int step: the timestep at which the collect_result is logged.
        """
        if collect_result["n/ep"] > 0 and step - self.last_log_train_step >= self.train_interval:
            log_data = {
                "train/episode": collect_result["n/ep"],
                "train/reward": collect_result["rew"],
                "train/length": collect_result["len"],
            }
            self.write("train/env_step", step, log_data)
            self.last_log_train_step = step

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during evaluation.

        :param collect_result: a dict containing information about the data collected
            during the evaluation stage, i.e., the return of collector.collect().
        :param int step: the timestep at which the collect_result is logged.
        """
        assert collect_result["n/ep"] > 0
        if step - self.last_log_test_step >= self.test_interval:
            log_data = {
                "test/env_step": step,
                "test/reward": collect_result["rew"],
                "test/length": collect_result["len"],
                "test/reward_std": collect_result["rew_std"],
                "test/length_std": collect_result["len_std"],
            }
            self.write("test/env_step", step, log_data)
            self.last_log_test_step = step

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Use writer to log statistics generated during updating.

        :param update_result: a dict containing information about the data collected
            during the updating stage, i.e., the return of policy.update().
        :param int step: the timestep at which the update_result is logged.
        """
        if step - self.last_log_update_step >= self.update_interval:
            log_data = {f"update/{k}": v for k, v in update_result.items()}
            self.write("update/gradient_step", step, log_data)
            self.last_log_update_step = step

    @abstractmethod
    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

        :param int epoch: the epoch in trainer.
        :param int env_step: the env_step in trainer.
        :param int gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a user-defined hook; see the trainer
            documentation for details.
        """

    @abstractmethod
    def restore_data(self) -> tuple[int, int, int]:
        """Return the metadata from an existing log.

        If nothing is found or an error occurs during the recovery process,
        the default parameters are returned.

        :return: epoch, env_step, gradient_step.
        """


class LazyLogger(BaseLogger):
    """A logger that does nothing. Used as the placeholder in trainer."""

    def __init__(self) -> None:
        super().__init__()

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """The LazyLogger writes nothing."""

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        pass

    def restore_data(self) -> tuple[int, int, int]:
        return 0, 0, 0
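As a usage note, here is a minimal sketch of a concrete subclass, assuming only the BaseLogger interface above. The name PrintLogger and its print-based write() are hypothetical, for illustration only; they are not part of this file.

# Assumes BaseLogger, LOG_DATA_TYPE, and Callable from the module above are in scope.
class PrintLogger(BaseLogger):
    """Hypothetical logger that writes every data dict to stdout."""

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        # Each call prints the namespace, the step, and the logged values.
        print(f"[{step_type} @ {step}] {data}")

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        pass  # this sketch persists nothing

    def restore_data(self) -> tuple[int, int, int]:
        return 0, 0, 0  # no saved state to recover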