diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index 35bf3e2..ce8a936 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
         envs.tracker = MetricTracker()
         try:
             # TODO: why on earth 10**18?
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()
 
@@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
     for _ in range(3):
         envs.tracker = MetricTracker()
         try:
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()
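A note on the idiom in these test hunks: the collector is driven with an effectively unbounded step budget (10**18), and the finite environment signals exhaustion of its dataset by raising StopIteration, after which the tracker is validated. A minimal self-contained analogue of the pattern, assuming nothing from tianshou (all names below are illustrative):

    # Drive an iterator with a huge step budget; StopIteration is the
    # completion signal rather than an error.
    def collect(source, n_step: int) -> list[int]:
        out = []
        for _ in range(n_step):
            out.append(next(source))  # raises StopIteration when exhausted
        return out

    data = iter(range(5))
    try:
        collect(data, n_step=10**18)
    except StopIteration:
        print("finite source exhausted; the real test validates its tracker here")

Any sufficiently large budget behaves the same way, since the loop never reaches it: the finite source always raises first.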
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index cf897a5..133667f 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -27,7 +27,7 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import in_eval_mode
 
 log = logging.getLogger(__name__)
 
@@ -260,8 +260,16 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass
 
-    def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
-                no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
         """Collect a specified number of steps or episodes.
 
         To ensure an unbiased sampling result with the n_episode option, this function will
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index fdfc4c0..c131326 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -130,7 +130,10 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
+            train_collector.collect(
+                n_step=self.sampling_config.start_timesteps,
+                random=self.sampling_config.start_timesteps_random,
+            )
         return train_collector, test_collector
 
     def set_policy_wrapper_factory(
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index 8c79428..d1ce28d 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -107,7 +107,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 3ef82be..80bcff6 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -197,7 +197,11 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
             # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
             dist = self.dist_fn(action_dist_input_BD)
 
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
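The discrete_sac.py and pg.py hunks reformat the same action-selection gate: take the distribution's mode only when deterministic_eval is set and the policy is not inside a training step; otherwise sample. A self-contained sketch of that gate, assuming only PyTorch (Categorical.mode requires torch >= 1.13); the function name and boolean parameters below are stand-ins for the policy attributes, not tianshou API:

    import torch
    from torch.distributions import Categorical

    def select_action(
        logits: torch.Tensor,
        deterministic_eval: bool,
        is_within_training_step: bool,
    ) -> torch.Tensor:
        dist = Categorical(logits=logits)
        # Deterministic actions only at evaluation time; during training
        # steps the policy keeps sampling, so exploration is unaffected.
        if deterministic_eval and not is_within_training_step:
            return dist.mode
        return dist.sample()

    logits = torch.randn(4, 3)  # batch of 4 observations, 3 discrete actions
    print(select_action(logits, deterministic_eval=True, is_within_training_step=False))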
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 213da7b..6e46374 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -5,7 +5,6 @@ from collections import defaultdict, deque
 from collections.abc import Callable
 from contextlib import contextmanager
 from dataclasses import asdict
-from typing import Optional, Tuple
 
 import numpy as np
 import tqdm
@@ -305,7 +304,6 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-
                 train_stat, update_stat, self.stop_fn_flag = self.training_step()
 
                 if isinstance(train_stat, CollectStats):
@@ -417,9 +415,8 @@ class BaseTrainer(ABC):
         finally:
             self.policy.is_within_training_step = old_value
 
-    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+    def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
         with self._is_within_training_step_enabled(True):
-
             should_stop_training = False
 
             if self.train_collector is not None:
@@ -438,7 +435,7 @@ class BaseTrainer(ABC):
             return collect_stats, training_stats, should_stop_training
 
     def _collect_training_data(self) -> CollectStats:
-        """Performs training data collection
+        """Performs training data collection.
 
         :return: the data collection stats
         """
@@ -446,7 +443,10 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )
 
         self.env_step += collect_stats.n_collected_steps
 
@@ -465,8 +465,7 @@ class BaseTrainer(ABC):
         return collect_stats
 
     def _test_in_train(self, collect_stats: CollectStats) -> bool:
-        """
-        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it. Finally, if the latter is also True, will set should_stop_training to True.
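For context on the is_within_training_step flag consulted by the reformatted ternaries: the hunks above show training_step wrapping its body in self._is_within_training_step_enabled(True), with a finally block restoring old_value. A hedged, self-contained sketch of such a flag-toggling context manager, using a tiny stand-in class (PolicyStub) in place of the real policy:

    from collections.abc import Iterator
    from contextlib import contextmanager

    class PolicyStub:
        is_within_training_step: bool = False

    policy = PolicyStub()

    @contextmanager
    def is_within_training_step_enabled(enabled: bool) -> Iterator[None]:
        old_value = policy.is_within_training_step
        policy.is_within_training_step = enabled
        try:
            yield
        finally:
            # Restore the previous value even if the wrapped block raises.
            policy.is_within_training_step = old_value

    with is_within_training_step_enabled(True):
        assert policy.is_within_training_step  # sampling path active here
    assert not policy.is_within_training_step  # flag restored afterwards

The try/finally is what makes the flag exception-safe: an error raised mid-step cannot leave the policy stuck in training mode.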