Improved docstrings, added asserts to make mypy happy

This commit is contained in:
Michael Panchenko 2024-05-05 15:56:06 +02:00
parent c5d0e169b5
commit 26a6cca76e
3 changed files with 12 additions and 28 deletions

View File

@ -279,7 +279,6 @@ class BaseCollector(ABC):
:param n_episode: how many episodes you want to collect.
:param random: whether to use random policy for collecting data.
:param render: the sleep time between rendering consecutive frames.
:param no_grad: whether to retain gradient in policy.forward().
:param reset_before_collect: whether to reset the environment before collecting data.
(The collector needs the initial obs and info to function properly.)
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
@ -343,11 +342,6 @@ class Collector(BaseCollector):
# D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
# S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
# Only used in n_episode case. Then, R becomes R-S.
# set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
# evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
# policy.deterministic_eval)
def __init__(
self,
policy: BasePolicy,

View File

@ -2,7 +2,7 @@ import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Callable
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import asdict
@ -360,7 +360,10 @@ class BaseTrainer(ABC):
self.logger.log_info_data(asdict(info_stat), self.epoch)
# in case trainer is used with run(), epoch_stat will not be returned
epoch_stat: EpochStats = EpochStats(
assert (
update_stat is not None
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
return EpochStats(
epoch=self.epoch,
train_collect_stat=train_stat,
test_collect_stat=test_stat,
@ -368,8 +371,6 @@ class BaseTrainer(ABC):
info_stat=info_stat,
)
return epoch_stat
def test_step(self) -> tuple[CollectStats, bool]:
"""Perform one testing step."""
assert self.episode_per_test is not None
@ -407,7 +408,7 @@ class BaseTrainer(ABC):
return test_stat, stop_fn_flag
@contextmanager
def _is_within_training_step_enabled(self, is_within_training_step: bool):
def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
old_value = self.policy.is_within_training_step
try:
self.policy.is_within_training_step = is_within_training_step
@ -419,10 +420,12 @@ class BaseTrainer(ABC):
with self._is_within_training_step_enabled(True):
should_stop_training = False
collect_stats: CollectStatsBase | CollectStats
if self.train_collector is not None:
collect_stats = self._collect_training_data()
should_stop_training = self._test_in_train(collect_stats)
else:
assert self.buffer is not None, "Either train_collector or buffer must be provided."
collect_stats = CollectStatsBase(
n_collected_episodes=len(self.buffer),
)
@ -484,6 +487,7 @@ class BaseTrainer(ABC):
and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore
):
assert self.test_collector is not None
assert self.episode_per_test is not None and self.episode_per_test > 0
test_result = test_episode(
self.test_collector,
self.test_fn,

View File

@ -49,23 +49,9 @@ def gather_info(
) -> InfoStats:
"""A simple wrapper for gathering information from collectors.
:return: A dataclass object with the following members (depending on available collectors):
* ``gradient_step`` the total number of gradient steps;
* ``best_reward`` the best reward over the test results;
* ``best_reward_std`` the standard deviation of best reward over the test results;
* ``train_step`` the total collected step of training collector;
* ``train_episode`` the total collected episode of training collector;
* ``test_step`` the total collected step of test collector;
* ``test_episode`` the total collected episode of test collector;
* ``timing`` the timing statistics, with the following members:
* ``total_time`` the total time elapsed;
* ``train_time`` the total time elapsed for learning training (collecting samples plus model update);
* ``train_time_collect`` the time for collecting transitions in the \
training collector;
* ``train_time_update`` the time for training models;
* ``test_time`` the time for testing;
* ``update_speed`` the speed of updating (env_step per second).
:return: InfoStats object with times computed based on the `start_time` and
episode/step counts read off the collectors. No computation of
expensive statistics is done here.
"""
duration = max(0.0, time.time() - start_time)
test_time = 0.0