Formatting

Michael Panchenko 2024-05-05 15:16:16 +02:00
parent 6927eadaa7
commit f876198870
6 changed files with 34 additions and 16 deletions

View File

@@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
         envs.tracker = MetricTracker()
         try:
             # TODO: why on earth 10**18?
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()
@@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
     for _ in range(3):
         envs.tracker = MetricTracker()
         try:
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()
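
The only substantive change in the two test hunks above is operator spacing: PEP 8, as enforced by black, drops the spaces around ** when both operands are simple. A minimal illustration, independent of the diff (that the repository uses black is an assumption based on the style of the other hunks):

    bits = 64
    n = 10**18           # simple operands: no spaces around **
    x = 2 ** (bits - 1)  # with a complex operand, the spaces are kept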

View File

@@ -27,7 +27,7 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import in_eval_mode

 log = logging.getLogger(__name__)
@@ -260,8 +260,16 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass

-    def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
-                no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
         """Collect a specified number of steps or episodes.

         To ensure an unbiased sampling result with the n_episode option, this function will
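
Exploding the collect signature to one parameter per line makes its many optional keyword arguments easier to scan and to diff. A hypothetical call-site sketch (the collector and envs setup is assumed, not part of this diff; only the keyword names come from the signature above):

    collector = Collector(policy, envs)  # assumed setup
    stats = collector.collect(
        n_step=1024,
        random=False,
        reset_before_collect=True,
    )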

View File

@@ -130,7 +130,10 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
+            train_collector.collect(
+                n_step=self.sampling_config.start_timesteps,
+                random=self.sampling_config.start_timesteps_random,
+            )
         return train_collector, test_collector

     def set_policy_wrapper_factory(
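
The comma after the last keyword argument is significant here: black's "magic trailing comma" pins the one-argument-per-line layout, so a later reformat will not collapse the call back onto a single line. A generic sketch of the behavior:

    def f(a: int, b: int) -> int:
        return a + b

    result = f(
        a=1,
        b=2,
    )  # the trailing comma keeps this layout on future reformats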

View File

@@ -107,7 +107,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
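
Here the conditional expression is wrapped in parentheses so each branch gets its own line; the identical change recurs in the PGPolicy hunk below. The pattern itself, greedy action at evaluation time versus sampling during training, can be sketched self-contained with just PyTorch (the deterministic flag stands in for deterministic_eval and not is_within_training_step):

    import torch
    from torch.distributions import Categorical

    dist = Categorical(logits=torch.tensor([0.1, 2.0, 0.3]))
    deterministic = True
    act = (
        dist.mode           # greedy: index of the largest probability
        if deterministic
        else dist.sample()  # stochastic: sample during training
    )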

View File

@@ -197,7 +197,11 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)

View File

@@ -5,7 +5,6 @@ from collections import defaultdict, deque
 from collections.abc import Callable
 from contextlib import contextmanager
 from dataclasses import asdict
-from typing import Optional, Tuple

 import numpy as np
 import tqdm
@@ -305,7 +304,6 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-
                 train_stat, update_stat, self.stop_fn_flag = self.training_step()
                 if isinstance(train_stat, CollectStats):
@@ -417,9 +415,8 @@ class BaseTrainer(ABC):
         finally:
             self.policy.is_within_training_step = old_value

-    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+    def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
         with self._is_within_training_step_enabled(True):
             should_stop_training = False
             if self.train_collector is not None:
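
Together with the removal of from typing import Optional, Tuple at the top of this file, this hunk modernizes the annotations: PEP 585 builtin generics (tuple[...]) and PEP 604 unions (X | None) replace their typing-module equivalents, and need no __future__ import on Python 3.10+. A generic sketch of the two forms:

    # pre-3.9 style, now removed:  Tuple[int, Optional[str], bool]
    def step() -> tuple[int, str | None, bool]:  # PEP 585 + PEP 604 style
        return 0, None, False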
@@ -438,7 +435,7 @@ class BaseTrainer(ABC):
         return collect_stats, training_stats, should_stop_training

     def _collect_training_data(self) -> CollectStats:
-        """Performs training data collection
+        """Performs training data collection.

         :return: the data collection stats
         """
@@ -446,7 +443,10 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )

         self.env_step += collect_stats.n_collected_steps
@@ -465,8 +465,7 @@ class BaseTrainer(ABC):
         return collect_stats

     def _test_in_train(self, collect_stats: CollectStats) -> bool:
-        """
-        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.

         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it.

         Finally, if the latter is also True, will set should_stop_training to True.
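
The two docstring edits in this file match common pydocstyle conventions: the summary line ends with a period (rule D400) and, in a multi-line docstring, the summary starts on the same line as the opening quotes (rule D212). Whether the project enforces these via a linter is an assumption; the shape they produce is:

    def f() -> None:
        """Summary on the first line, ending with a period.

        Further detail follows after a blank line.
        """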