Formatting

Michael Panchenko 2024-05-05 15:16:16 +02:00
parent 6927eadaa7
commit f876198870
6 changed files with 34 additions and 16 deletions


@@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
envs.tracker = MetricTracker()
try:
# TODO: why on earth 10**18?
-test_collector.collect(n_step=10 ** 18)
+test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()
@@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
for _ in range(3):
envs.tracker = MetricTracker()
try:
-test_collector.collect(n_step=10 ** 18)
+test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()

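For context on the 10**18 step budget that the TODO above questions: these tests run a finite vector env until it is exhausted, so the step budget only needs to be effectively unbounded, and collection actually ends when the env raises StopIteration. A minimal sketch of that pattern follows; the FiniteDatasetEnv class and its three-item dataset are illustrative, not part of this commit:

import gymnasium as gym
import numpy as np


class FiniteDatasetEnv(gym.Env):
    """Toy env that serves each dataset item exactly once, then signals exhaustion from reset()."""

    observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def __init__(self, dataset):
        self._it = iter(dataset)

    def reset(self, *, seed=None, options=None):
        # next() raises StopIteration once the dataset is used up, which is exactly
        # what the try/except StopIteration blocks in the tests rely on
        return np.array([next(self._it)], dtype=np.float32), {}

    def step(self, action):
        # single-step episodes keep the sketch short
        return np.zeros(1, dtype=np.float32), 0.0, True, False, {}


env = FiniteDatasetEnv([1.0, 2.0, 3.0])
try:
    while True:  # stand-in for collect(n_step=10**18): "collect forever"
        env.reset()
        env.step(env.action_space.sample())
except StopIteration:
    print("dataset exhausted; collection stops here")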

@@ -27,7 +27,7 @@ from tianshou.data.types import (
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import in_eval_mode
log = logging.getLogger(__name__)
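Only in_eval_mode remains imported here. For readers unfamiliar with the helper, a sketch of what such a context manager typically does follows; this is an illustrative re-implementation, not the code from tianshou.utils.torch_utils:

from contextlib import contextmanager

import torch


@contextmanager
def in_eval_mode(module: torch.nn.Module):
    """Temporarily put a module into eval mode, restoring its previous mode on exit."""
    was_training = module.training
    module.eval()
    try:
        yield
    finally:
        module.train(was_training)


net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Dropout(0.5), torch.nn.Linear(8, 2))
with in_eval_mode(net):
    assert not net.training  # dropout is disabled while collecting/evaluating
assert net.training  # original mode restored afterwards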
@@ -260,8 +260,16 @@ class BaseCollector(ABC):
) -> CollectStats:
pass
-def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
-no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
+def collect(
+    self,
+    n_step: int | None = None,
+    n_episode: int | None = None,
+    random: bool = False,
+    render: float | None = None,
+    no_grad: bool = True,
+    reset_before_collect: bool = False,
+    gym_reset_kwargs: dict[str, Any] | None = None,
+) -> CollectStats:
"""Collect a specified number of steps or episodes.
To ensure an unbiased sampling result with the n_episode option, this function will

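The reformatted signature above makes the call options easier to scan. A hedged usage sketch follows; policy stands for any already constructed tianshou BasePolicy (its setup is omitted), while the Collector, DummyVectorEnv, and VectorReplayBuffer calls mirror the library's documented API:

import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
buffer = VectorReplayBuffer(total_size=20_000, buffer_num=len(train_envs))

# `policy` is assumed to be an already constructed BasePolicy (e.g. a DQN policy for CartPole)
collector = Collector(policy, train_envs, buffer, exploration_noise=True)
collector.reset()  # or pass reset_before_collect=True to collect() below

stats = collector.collect(n_step=2_000, random=True)  # budget in environment steps
stats = collector.collect(n_episode=10)               # or a budget in episodes
print(stats.n_collected_steps, stats.n_collected_episodes)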

@@ -130,7 +130,10 @@ class AgentFactory(ABC, ToStringMixin):
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
-train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
+train_collector.collect(
+    n_step=self.sampling_config.start_timesteps,
+    random=self.sampling_config.start_timesteps_random,
+)
return train_collector, test_collector
def set_policy_wrapper_factory(

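In the high-level API, this warm-up collection is driven by the sampling config. A sketch of the relevant settings follows; the two start_timesteps* field names come from the diff above, while the import path and the remaining fields are assumptions about the high-level config:

from tianshou.highlevel.config import SamplingConfig  # import path assumed

sampling_config = SamplingConfig(
    num_epochs=10,
    step_per_epoch=10_000,
    step_per_collect=200,
    start_timesteps=5_000,         # prefill the buffer with 5 000 environment steps...
    start_timesteps_random=True,   # ...taken with random actions, before training begins
)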

@@ -107,7 +107,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
) -> Batch:
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits_BA)
-act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+act_B = (
+    dist.mode
+    if self.deterministic_eval and not self.is_within_training_step
+    else dist.sample()
+)
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:

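The reformatted expression above switches between the distribution's mode and a sample. A small standalone illustration of that behaviour with torch.distributions.Categorical (the flag values are made up for the example):

import torch
from torch.distributions import Categorical

logits_BA = torch.tensor([[2.0, 0.5, 0.1]])
dist = Categorical(logits=logits_BA)

deterministic_eval, is_within_training_step = True, False
act_B = (
    dist.mode  # greedy arg-max action, used for deterministic evaluation
    if deterministic_eval and not is_within_training_step
    else dist.sample()  # stochastic draw, used while training/exploring
)
print(act_B)  # tensor([0]) -- index of the largest logit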

@@ -197,7 +197,11 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(action_dist_input_BD)
-act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+act_B = (
+    dist.mode
+    if self.deterministic_eval and not self.is_within_training_step
+    else dist.sample()
+)
# act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result)

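The comment above notes that, in the continuous case, action_dist_input_BD is a (loc_B, scale_B) tuple. A sketch of a matching dist_fn that unpacks it into an independent Normal; the function itself is illustrative, not part of the diff:

import torch
from torch.distributions import Independent, Normal


def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Independent:
    # unpack the actor's (loc_B, scale_B) output into one distribution over the action dims
    loc_B, scale_B = loc_scale
    return Independent(Normal(loc_B, scale_B), 1)


dist = dist_fn((torch.zeros(4, 2), torch.ones(4, 2)))  # batch of 4, action dim 2
print(dist.mode.shape, dist.sample().shape)  # both torch.Size([4, 2])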

@@ -5,7 +5,6 @@ from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import asdict
-from typing import Optional, Tuple
import numpy as np
import tqdm
@@ -305,7 +304,6 @@ class BaseTrainer(ABC):
with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
train_stat: CollectStatsBase
while t.n < t.total and not self.stop_fn_flag:
train_stat, update_stat, self.stop_fn_flag = self.training_step()
if isinstance(train_stat, CollectStats):
@@ -417,9 +415,8 @@
finally:
self.policy.is_within_training_step = old_value
-def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
with self._is_within_training_step_enabled(True):
should_stop_training = False
if self.train_collector is not None:
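The signature change above is part of dropping typing.Tuple/typing.Optional in favour of built-in generics and the union operator (PEP 585/604). A minimal before/after illustration with a hypothetical function:

from typing import Optional, Tuple  # only needed for the old-style annotation


def stats_old() -> Tuple[int, Optional[float], bool]:  # pre-commit style
    return 0, None, False


def stats_new() -> tuple[int, float | None, bool]:  # style used after this commit (Python >= 3.10)
    return 0, None, False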
@@ -438,7 +435,7 @@
return collect_stats, training_stats, should_stop_training
def _collect_training_data(self) -> CollectStats:
"""Performs training data collection
"""Performs training data collection.
:return: the data collection stats
"""
@@ -446,7 +443,10 @@
assert self.train_collector is not None
if self.train_fn:
self.train_fn(self.epoch, self.env_step)
-collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
+collect_stats = self.train_collector.collect(
+    n_step=self.step_per_collect,
+    n_episode=self.episode_per_collect,
+)
self.env_step += collect_stats.n_collected_steps
@@ -465,8 +465,7 @@
return collect_stats
def _test_in_train(self, collect_stats: CollectStats) -> bool:
"""
If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
Then, if the stop_fn is True there, will collect test data and also compute the stop_fn of the mean return
on it.
Finally, if the latter is also True, will set should_stop_training to True.
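
A compact sketch of the control flow this docstring describes; the helper below and its arguments are illustrative stand-ins, not the trainer's actual attributes:

from collections.abc import Callable


def should_stop_training(
    train_return_mean: float,
    stop_fn: Callable[[float], bool] | None,
    collect_test_return_mean: Callable[[], float],
    test_in_train: bool = True,
) -> bool:
    if not (test_in_train and stop_fn is not None):
        return False
    if not stop_fn(train_return_mean):  # 1) check the criterion on the training-data return
        return False
    test_return_mean = collect_test_return_mean()  # 2) it passed: collect test data
    return stop_fn(test_return_mean)  # 3) stop only if the test return also passes


# usage: stop once the mean return reaches 195 on both training and test rollouts
print(should_stop_training(200.0, lambda ret: ret >= 195.0, lambda: 198.0))  # True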