Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility of passing `None` as `critic2` and `critic2_optim`.
The resulting default behavior should cover the vast majority of
cases
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None`

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already
non-optional, as there was a ValueError in `BasePolicy.__init__`. Several
examples were fixed to reflect that
2. `reward_normalization` was removed from DDPG and its children. Passing it
as `True` was never allowed there, since an error would have been raised in
`compute_n_step_reward`. It is now removed from the interface
3. Renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)

## Additional changes: 
1. Removed type and default declaration from docstring. This kind of
duplication is really not necessary
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named as before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)
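
As an illustration of the switch to named-only constructor calls, here is a minimal sketch of how Python can enforce it with a bare `*` in the signature. `ToyPolicy` and its parameters are hypothetical stand-ins chosen for this example; they are not tianshou classes, and the commit itself does not state which mechanism is used.

```python
class ToyPolicy:
    """Hypothetical stand-in for a policy class; not part of tianshou."""

    def __init__(
        self,
        *,  # everything after the bare `*` must be passed by name
        actor: str,
        critic: str,
        action_space: str,
        tau: float = 0.005,
    ) -> None:
        self.actor = actor
        self.critic = critic
        self.action_space = action_space
        self.tau = tau


# Named arguments: the call site stays readable and survives parameter reordering.
policy = ToyPolicy(actor="actor_net", critic="critic_net", action_space="Box(1,)")

# Positional arguments are rejected with a TypeError.
try:
    ToyPolicy("actor_net", "critic_net", "Box(1,)")
    positional_call_rejected = False
except TypeError:
    positional_call_rejected = True
```

With a keyword-only signature, a mixed positional/named call fails at construction time instead of silently binding a value to the wrong parameter.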

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00


from abc import ABC, abstractmethod
from collections.abc import Callable
from numbers import Number

import numpy as np

LOG_DATA_TYPE = dict[str, int | Number | np.number | np.ndarray]


class BaseLogger(ABC):
    """The base class for any logger which is compatible with trainer.

    Override the write() method to use your own writer.

    :param train_interval: the log interval in log_train_data(). Defaults to 1000.
    :param test_interval: the log interval in log_test_data(). Defaults to 1.
    :param update_interval: the log interval in log_update_data(). Defaults to 1000.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
    ) -> None:
        super().__init__()
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        # Step of the most recent write per category; -1 means nothing logged yet.
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1

    @abstractmethod
    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """Specify how the writer is used to log data.

        :param str step_type: namespace which the data dict belongs to.
        :param step: the step at which the data dict is recorded.
        :param data: the data to write with format ``{key: value}``.
        """

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during training.

        :param collect_result: a dict containing information about the data collected
            in the training stage, i.e., the return value of collector.collect().
        :param step: the timestep at which collect_result is logged.
        """
        if collect_result["n/ep"] > 0 and step - self.last_log_train_step >= self.train_interval:
            log_data = {
                "train/episode": collect_result["n/ep"],
                "train/reward": collect_result["rew"],
                "train/length": collect_result["len"],
            }
            self.write("train/env_step", step, log_data)
            self.last_log_train_step = step

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Use writer to log statistics generated during evaluation.

        :param collect_result: a dict containing information about the data collected
            in the evaluation stage, i.e., the return value of collector.collect().
        :param step: the timestep at which collect_result is logged.
        """
        assert collect_result["n/ep"] > 0
        if step - self.last_log_test_step >= self.test_interval:
            log_data = {
                "test/env_step": step,
                "test/reward": collect_result["rew"],
                "test/length": collect_result["len"],
                "test/reward_std": collect_result["rew_std"],
                "test/length_std": collect_result["len_std"],
            }
            self.write("test/env_step", step, log_data)
            self.last_log_test_step = step

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Use writer to log statistics generated during updating.

        :param update_result: a dict containing information about the data collected
            in the updating stage, i.e., the return value of policy.update().
        :param step: the timestep at which update_result is logged.
        """
        if step - self.last_log_update_step >= self.update_interval:
            log_data = {f"update/{k}": v for k, v in update_result.items()}
            self.write("update/gradient_step", step, log_data)
            self.last_log_update_step = step

    @abstractmethod
    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

        :param epoch: the epoch in trainer.
        :param env_step: the env_step in trainer.
        :param gradient_step: the gradient_step in trainer.
        :param function save_checkpoint_fn: a hook defined by user, see trainer
            documentation for detail.
        """

    @abstractmethod
    def restore_data(self) -> tuple[int, int, int]:
        """Return the metadata from an existing log.

        If nothing is found or an error occurs during recovery, the default
        parameters are returned.

        :return: epoch, env_step, gradient_step.
        """


class LazyLogger(BaseLogger):
    """A logger that does nothing. Used as the placeholder in trainer."""

    def __init__(self) -> None:
        super().__init__()

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        """The LazyLogger writes nothing."""

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        pass

    def restore_data(self) -> tuple[int, int, int]:
        return 0, 0, 0
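
The interval check shared by `log_train_data`, `log_test_data` and `log_update_data` can be studied in isolation. The following is a standalone sketch of that check only; `IntervalGate` and `should_log` are hypothetical names introduced for this illustration and are not part of the module above.

```python
class IntervalGate:
    """Standalone sketch of the step-interval check used by the log_* methods.

    Hypothetical helper for illustration; not part of the logger module.
    """

    def __init__(self, interval: int) -> None:
        self.interval = interval
        # Same initial value as the last_log_*_step attributes in BaseLogger.
        self.last_step = -1

    def should_log(self, step: int) -> bool:
        # Log only if at least `interval` steps have passed since the last write.
        if step - self.last_step >= self.interval:
            self.last_step = step
            return True
        return False


# Offer a step every 250 environment steps with an interval of 1000.
gate = IntervalGate(interval=1000)
logged_steps = [step for step in range(0, 3001, 250) if gate.should_log(step)]
print(logged_steps)  # [1000, 2000, 3000]
```

Because the last-logged step starts at -1, step 0 passes the check only when the interval is 1 (as with the default `test_interval`); with an interval of 1000 the first write happens once the step reaches 999 or more.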