Improved docstrings, added asserts to make mypy happy

This commit is contained in:
Michael Panchenko 2024-05-05 15:56:06 +02:00
parent c5d0e169b5
commit 26a6cca76e
3 changed files with 12 additions and 28 deletions

View File

@@ -279,7 +279,6 @@ class BaseCollector(ABC):
         :param n_episode: how many episodes you want to collect.
         :param random: whether to use random policy for collecting data.
         :param render: the sleep time between rendering consecutive frames.
-        :param no_grad: whether to retain gradient in policy.forward().
         :param reset_before_collect: whether to reset the environment before collecting data.
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
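For orientation, here is a rough sketch of how the documented parameters combine in a collect call. The policy, environment, and buffer setup are assumptions for illustration, not part of this diff:

    import gymnasium as gym

    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv

    # `policy` is assumed to be an already-constructed tianshou BasePolicy.
    envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    collector = Collector(policy, envs, VectorReplayBuffer(20_000, len(envs)))
    stats = collector.collect(
        n_episode=10,               # how many episodes to collect
        random=False,               # act with the policy, not at random
        render=0.0,                 # sleep time between rendered frames
        reset_before_collect=True,  # collector needs the initial obs and info
    )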
@@ -343,11 +342,6 @@ class Collector(BaseCollector):
     # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
     # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
     #     Only used in n_episode case. Then, R becomes R-S.
-    # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
-    # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
-    # policy.deterministic_eval)

     def __init__(
         self,
         policy: BasePolicy,
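The deleted comment had described two distinct notions of "evaluation mode". As a plain-PyTorch illustration of the distinction (a sketch, not Tianshou's actual internals; `dist.mode` assumes a reasonably recent torch version):

    import torch

    # 1) Module eval mode: affects layers such as dropout and batchnorm.
    net = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Dropout(0.5))
    net.eval()  # dropout becomes a no-op, batchnorm would use running stats
    assert not net.training

    # 2) Policy evaluation mode (cf. policy.deterministic_eval): affects how
    #    actions are produced, e.g. taking the distribution's mode rather
    #    than sampling from it. Sketched as a standalone function:
    def select_action(
        dist: torch.distributions.Distribution, deterministic: bool
    ) -> torch.Tensor:
        return dist.mode if deterministic else dist.sample()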

View File

@@ -2,7 +2,7 @@ import logging
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
 from dataclasses import asdict
@@ -360,7 +360,10 @@ class BaseTrainer(ABC):
         self.logger.log_info_data(asdict(info_stat), self.epoch)
-        # in case trainer is used with run(), epoch_stat will not be returned
-        epoch_stat: EpochStats = EpochStats(
+        assert (
+            update_stat is not None
+        ), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
+        return EpochStats(
             epoch=self.epoch,
             train_collect_stat=train_stat,
             test_collect_stat=test_stat,
@@ -368,8 +371,6 @@
             info_stat=info_stat,
         )
-        return epoch_stat

     def test_step(self) -> tuple[CollectStats, bool]:
         """Perform one testing step."""
         assert self.episode_per_test is not None
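The new assert exists purely for type narrowing: mypy cannot prove that the loop above always assigns `update_stat`, so the assert documents the invariant and removes the Optional. The same pattern in isolation (a minimal sketch, not Tianshou code):

    def last_positive(xs: list[int]) -> int:
        result: int | None = None
        for x in xs:
            if x > 0:
                result = x
        # Here mypy infers `int | None`; the assert narrows it to `int` and
        # fails loudly with a useful message if the invariant is ever broken.
        assert result is not None, "expected at least one positive element"
        return result

The `assert self.episode_per_test is not None and self.episode_per_test > 0` added further down serves the same narrowing purpose, with an extra runtime sanity check on top.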
@@ -407,7 +408,7 @@
         return test_stat, stop_fn_flag

     @contextmanager
-    def _is_within_training_step_enabled(self, is_within_training_step: bool):
+    def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
         old_value = self.policy.is_within_training_step
         try:
             self.policy.is_within_training_step = is_within_training_step
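A generator-based context manager yields exactly once, so under mypy the conventional return annotation is `Iterator[None]` (hence the new import above). A self-contained sketch of the same save/set/restore pattern, with illustrative names:

    from collections.abc import Iterator
    from contextlib import contextmanager

    class HasFlag:
        is_within_training_step: bool = False

    @contextmanager
    def flag_enabled(obj: HasFlag, value: bool) -> Iterator[None]:
        # Save the old value, set the new one, and restore on exit even if
        # the body of the `with` block raises.
        old_value = obj.is_within_training_step
        try:
            obj.is_within_training_step = value
            yield
        finally:
            obj.is_within_training_step = old_value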
@@ -419,10 +420,12 @@
         with self._is_within_training_step_enabled(True):
             should_stop_training = False
+            collect_stats: CollectStatsBase | CollectStats
             if self.train_collector is not None:
                 collect_stats = self._collect_training_data()
                 should_stop_training = self._test_in_train(collect_stats)
             else:
+                assert self.buffer is not None, "Either train_collector or buffer must be provided."
                 collect_stats = CollectStatsBase(
                     n_collected_episodes=len(self.buffer),
                 )
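The up-front declaration of `collect_stats` is another mypy accommodation: without it, mypy infers the variable's type from the first assignment and then rejects the broader type assigned in the `else` branch. The same trick in miniature (class names here are illustrative):

    class StatsBase:
        ...

    class Stats(StatsBase):
        ...

    def gather(full: bool) -> StatsBase:
        # Declare the union up front; otherwise mypy infers `Stats` from the
        # first branch and rejects the `StatsBase` assignment in the second.
        stats: StatsBase | Stats
        if full:
            stats = Stats()
        else:
            stats = StatsBase()
        return stats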
@@ -484,6 +487,7 @@
                 and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
             ):
                 assert self.test_collector is not None
+                assert self.episode_per_test is not None and self.episode_per_test > 0
                 test_result = test_episode(
                     self.test_collector,
                     self.test_fn,

View File

@@ -49,23 +49,9 @@ def gather_info(
 ) -> InfoStats:
     """A simple wrapper of gathering information from collectors.

-    :return: A dataclass object with the following members (depending on available collectors):
-
-        * ``gradient_step`` the total number of gradient steps;
-        * ``best_reward`` the best reward over the test results;
-        * ``best_reward_std`` the standard deviation of best reward over the test results;
-        * ``train_step`` the total collected step of training collector;
-        * ``train_episode`` the total collected episode of training collector;
-        * ``test_step`` the total collected step of test collector;
-        * ``test_episode`` the total collected episode of test collector;
-        * ``timing`` the timing statistics, with the following members:
-
-            * ``total_time`` the total time elapsed;
-            * ``train_time`` the total time elapsed for learning training (collecting samples plus model update);
-            * ``train_time_collect`` the time for collecting transitions in the training collector;
-            * ``train_time_update`` the time for training models;
-            * ``test_time`` the time for testing;
-            * ``update_speed`` the speed of updating (env_step per second).
+    :return: InfoStats object with times computed based on the `start_time` and
+        episode/step counts read off the collectors. No computation of
+        expensive statistics is done here.
     """
     duration = max(0.0, time.time() - start_time)
     test_time = 0.0
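The trimmed docstring matches what the function actually does: derive wall-clock durations from `start_time` and read step/episode counters off the collectors. A rough sketch of that kind of bookkeeping (the helper and its field names are assumptions, not Tianshou's API):

    import time

    def gather_timing(
        start_time: float, test_time: float, update_time: float
    ) -> dict[str, float]:
        # Wall-clock time since training began; clamped at zero just like
        # the max(0.0, ...) guard visible in the diff above.
        duration = max(0.0, time.time() - start_time)
        train_time = max(0.0, duration - test_time)
        return {
            "total_time": duration,
            "train_time": train_time,
            "train_time_update": update_time,
            "test_time": test_time,
        }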