Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
148 lines
5.7 KiB
Python
148 lines
5.7 KiB
Python
import numpy as np
|
|
from numbers import Number
|
|
from typing import Any, Union
|
|
from abc import ABC, abstractmethod
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
class BaseLogger(ABC):
|
|
"""The base class for any logger which is compatible with trainer."""
|
|
|
|
def __init__(self, writer: Any) -> None:
|
|
super().__init__()
|
|
self.writer = writer
|
|
|
|
@abstractmethod
|
|
def write(
|
|
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
|
) -> None:
|
|
"""Specify how the writer is used to log data.
|
|
|
|
:param str key: namespace which the input data tuple belongs to.
|
|
:param int x: stands for the ordinate of the input data tuple.
|
|
:param y: stands for the abscissa of the input data tuple.
|
|
"""
|
|
pass
|
|
|
|
def log_train_data(self, collect_result: dict, step: int) -> None:
|
|
"""Use writer to log statistics generated during training.
|
|
|
|
:param collect_result: a dict containing information of data collected in
|
|
training stage, i.e., returns of collector.collect().
|
|
:param int step: stands for the timestep the collect_result being logged.
|
|
"""
|
|
pass
|
|
|
|
def log_update_data(self, update_result: dict, step: int) -> None:
|
|
"""Use writer to log statistics generated during updating.
|
|
|
|
:param update_result: a dict containing information of data collected in
|
|
updating stage, i.e., returns of policy.update().
|
|
:param int step: stands for the timestep the collect_result being logged.
|
|
"""
|
|
pass
|
|
|
|
def log_test_data(self, collect_result: dict, step: int) -> None:
|
|
"""Use writer to log statistics generated during evaluating.
|
|
|
|
:param collect_result: a dict containing information of data collected in
|
|
evaluating stage, i.e., returns of collector.collect().
|
|
:param int step: stands for the timestep the collect_result being logged.
|
|
"""
|
|
pass
|
|
|
|
|
|
class BasicLogger(BaseLogger):
|
|
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
|
|
and log statistics.
|
|
|
|
You can also rewrite write() func to use your own writer.
|
|
|
|
:param SummaryWriter writer: the writer to log data.
|
|
:param int train_interval: the log interval in log_train_data(). Default to 1.
|
|
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
|
:param int update_interval: the log interval in log_update_data(). Default to 1000.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
writer: SummaryWriter,
|
|
train_interval: int = 1,
|
|
test_interval: int = 1,
|
|
update_interval: int = 1000,
|
|
) -> None:
|
|
super().__init__(writer)
|
|
self.train_interval = train_interval
|
|
self.test_interval = test_interval
|
|
self.update_interval = update_interval
|
|
self.last_log_train_step = -1
|
|
self.last_log_test_step = -1
|
|
self.last_log_update_step = -1
|
|
|
|
def write(
|
|
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
|
) -> None:
|
|
self.writer.add_scalar(key, y, global_step=x)
|
|
|
|
def log_train_data(self, collect_result: dict, step: int) -> None:
|
|
"""Use writer to log statistics generated during training.
|
|
|
|
:param collect_result: a dict containing information of data collected in
|
|
training stage, i.e., returns of collector.collect().
|
|
:param int step: stands for the timestep the collect_result being logged.
|
|
|
|
.. note::
|
|
|
|
``collect_result`` will be modified in-place with "rew" and "len" keys.
|
|
"""
|
|
if collect_result["n/ep"] > 0:
|
|
collect_result["rew"] = collect_result["rews"].mean()
|
|
collect_result["len"] = collect_result["lens"].mean()
|
|
if step - self.last_log_train_step >= self.train_interval:
|
|
self.write("train/n/ep", step, collect_result["n/ep"])
|
|
self.write("train/rew", step, collect_result["rew"])
|
|
self.write("train/len", step, collect_result["len"])
|
|
self.last_log_train_step = step
|
|
|
|
def log_test_data(self, collect_result: dict, step: int) -> None:
|
|
"""Use writer to log statistics generated during evaluating.
|
|
|
|
:param collect_result: a dict containing information of data collected in
|
|
evaluating stage, i.e., returns of collector.collect().
|
|
:param int step: stands for the timestep the collect_result being logged.
|
|
|
|
.. note::
|
|
|
|
``collect_result`` will be modified in-place with "rew", "rew_std", "len",
|
|
and "len_std" keys.
|
|
"""
|
|
assert collect_result["n/ep"] > 0
|
|
rews, lens = collect_result["rews"], collect_result["lens"]
|
|
rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
|
|
collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
|
|
if step - self.last_log_test_step >= self.test_interval:
|
|
self.write("test/rew", step, rew)
|
|
self.write("test/len", step, len_)
|
|
self.write("test/rew_std", step, rew_std)
|
|
self.write("test/len_std", step, len_std)
|
|
self.last_log_test_step = step
|
|
|
|
def log_update_data(self, update_result: dict, step: int) -> None:
|
|
if step - self.last_log_update_step >= self.update_interval:
|
|
for k, v in update_result.items():
|
|
self.write(k, step, v)
|
|
self.last_log_update_step = step
|
|
|
|
|
|
class LazyLogger(BasicLogger):
|
|
"""A loggger that does nothing. Used as the placeholder in trainer."""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__(None) # type: ignore
|
|
|
|
def write(
|
|
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
|
|
) -> None:
|
|
"""The LazyLogger writes nothing."""
|
|
pass
|