Improved docstrings, added asserts to make mypy happy
This commit is contained in:
parent
c5d0e169b5
commit
26a6cca76e
@ -279,7 +279,6 @@ class BaseCollector(ABC):
|
||||
:param n_episode: how many episodes you want to collect.
|
||||
:param random: whether to use random policy for collecting data.
|
||||
:param render: the sleep time between rendering consecutive frames.
|
||||
:param no_grad: whether to disable gradient computation in policy.forward() (if True, no gradient is retained).
|
||||
:param reset_before_collect: whether to reset the environment before collecting data.
|
||||
(The collector needs the initial obs and info to function properly.)
|
||||
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
|
||||
@ -343,11 +342,6 @@ class Collector(BaseCollector):
|
||||
# D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
|
||||
# S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
|
||||
# Only used in n_episode case. Then, R becomes R-S.
|
||||
|
||||
# set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
|
||||
# evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
|
||||
# policy.deterministic_eval)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict
|
||||
|
||||
@ -360,7 +360,10 @@ class BaseTrainer(ABC):
|
||||
self.logger.log_info_data(asdict(info_stat), self.epoch)
|
||||
|
||||
# in case trainer is used with run(), epoch_stat will not be returned
|
||||
epoch_stat: EpochStats = EpochStats(
|
||||
assert (
|
||||
update_stat is not None
|
||||
), "Defined in the loop above, this shouldn't have happened and is likely a bug!"
|
||||
return EpochStats(
|
||||
epoch=self.epoch,
|
||||
train_collect_stat=train_stat,
|
||||
test_collect_stat=test_stat,
|
||||
@ -368,8 +371,6 @@ class BaseTrainer(ABC):
|
||||
info_stat=info_stat,
|
||||
)
|
||||
|
||||
return epoch_stat
|
||||
|
||||
def test_step(self) -> tuple[CollectStats, bool]:
|
||||
"""Perform one testing step."""
|
||||
assert self.episode_per_test is not None
|
||||
@ -407,7 +408,7 @@ class BaseTrainer(ABC):
|
||||
return test_stat, stop_fn_flag
|
||||
|
||||
@contextmanager
|
||||
def _is_within_training_step_enabled(self, is_within_training_step: bool):
|
||||
def _is_within_training_step_enabled(self, is_within_training_step: bool) -> Iterator[None]:
|
||||
old_value = self.policy.is_within_training_step
|
||||
try:
|
||||
self.policy.is_within_training_step = is_within_training_step
|
||||
@ -419,10 +420,12 @@ class BaseTrainer(ABC):
|
||||
with self._is_within_training_step_enabled(True):
|
||||
should_stop_training = False
|
||||
|
||||
collect_stats: CollectStatsBase | CollectStats
|
||||
if self.train_collector is not None:
|
||||
collect_stats = self._collect_training_data()
|
||||
should_stop_training = self._test_in_train(collect_stats)
|
||||
else:
|
||||
assert self.buffer is not None, "Either train_collector or buffer must be provided."
|
||||
collect_stats = CollectStatsBase(
|
||||
n_collected_episodes=len(self.buffer),
|
||||
)
|
||||
@ -484,6 +487,7 @@ class BaseTrainer(ABC):
|
||||
and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore
|
||||
):
|
||||
assert self.test_collector is not None
|
||||
assert self.episode_per_test is not None and self.episode_per_test > 0
|
||||
test_result = test_episode(
|
||||
self.test_collector,
|
||||
self.test_fn,
|
||||
|
@ -49,23 +49,9 @@ def gather_info(
|
||||
) -> InfoStats:
|
||||
"""A simple wrapper of gathering information from collectors.
|
||||
|
||||
:return: A dataclass object with the following members (depending on available collectors):
|
||||
|
||||
* ``gradient_step`` the total number of gradient steps;
|
||||
* ``best_reward`` the best reward over the test results;
|
||||
* ``best_reward_std`` the standard deviation of best reward over the test results;
|
||||
* ``train_step`` the total collected step of training collector;
|
||||
* ``train_episode`` the total collected episode of training collector;
|
||||
* ``test_step`` the total collected step of test collector;
|
||||
* ``test_episode`` the total collected episode of test collector;
|
||||
* ``timing`` the timing statistics, with the following members:
|
||||
* ``total_time`` the total time elapsed;
|
||||
* ``train_time`` the total time elapsed for training (collecting samples plus model update);
|
||||
* ``train_time_collect`` the time for collecting transitions in the \
|
||||
training collector;
|
||||
* ``train_time_update`` the time for training models;
|
||||
* ``test_time`` the time for testing;
|
||||
* ``update_speed`` the speed of updating (env_step per second).
|
||||
:return: InfoStats object with times computed based on the `start_time` and
|
||||
episode/step counts read off the collectors. No computation of
|
||||
expensive statistics is done here.
|
||||
"""
|
||||
duration = max(0.0, time.time() - start_time)
|
||||
test_time = 0.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user