Formatting
This commit is contained in:
parent
6927eadaa7
commit
f876198870
@ -27,7 +27,7 @@ from tianshou.data.types import (
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.utils.print import DataclassPPrintMixin
|
||||
from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
|
||||
from tianshou.utils.torch_utils import in_eval_mode
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -260,8 +260,16 @@ class BaseCollector(ABC):
|
||||
) -> CollectStats:
|
||||
pass
|
||||
|
||||
def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
|
||||
no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
|
||||
def collect(
|
||||
self,
|
||||
n_step: int | None = None,
|
||||
n_episode: int | None = None,
|
||||
random: bool = False,
|
||||
render: float | None = None,
|
||||
no_grad: bool = True,
|
||||
reset_before_collect: bool = False,
|
||||
gym_reset_kwargs: dict[str, Any] | None = None,
|
||||
) -> CollectStats:
|
||||
"""Collect a specified number of steps or episodes.
|
||||
|
||||
To ensure an unbiased sampling result with the n_episode option, this function will
|
||||
|
@ -130,7 +130,10 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
log.info(
|
||||
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
|
||||
)
|
||||
train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
|
||||
train_collector.collect(
|
||||
n_step=self.sampling_config.start_timesteps,
|
||||
random=self.sampling_config.start_timesteps_random,
|
||||
)
|
||||
return train_collector, test_collector
|
||||
|
||||
def set_policy_wrapper_factory(
|
||||
|
@ -107,7 +107,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
) -> Batch:
|
||||
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits_BA)
|
||||
act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
|
||||
act_B = (
|
||||
dist.mode
|
||||
if self.deterministic_eval and not self.is_within_training_step
|
||||
else dist.sample()
|
||||
)
|
||||
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
|
@ -197,7 +197,11 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
|
||||
dist = self.dist_fn(action_dist_input_BD)
|
||||
|
||||
act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
|
||||
act_B = (
|
||||
dist.mode
|
||||
if self.deterministic_eval and not self.is_within_training_step
|
||||
else dist.sample()
|
||||
)
|
||||
# act is of dimension BA in continuous case and of dimension B in discrete
|
||||
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
|
||||
return cast(DistBatchProtocol, result)
|
||||
|
@ -5,7 +5,6 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
@ -305,7 +304,6 @@ class BaseTrainer(ABC):
|
||||
with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
|
||||
train_stat: CollectStatsBase
|
||||
while t.n < t.total and not self.stop_fn_flag:
|
||||
|
||||
train_stat, update_stat, self.stop_fn_flag = self.training_step()
|
||||
|
||||
if isinstance(train_stat, CollectStats):
|
||||
@ -417,9 +415,8 @@ class BaseTrainer(ABC):
|
||||
finally:
|
||||
self.policy.is_within_training_step = old_value
|
||||
|
||||
def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
|
||||
def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
|
||||
with self._is_within_training_step_enabled(True):
|
||||
|
||||
should_stop_training = False
|
||||
|
||||
if self.train_collector is not None:
|
||||
@ -438,7 +435,7 @@ class BaseTrainer(ABC):
|
||||
return collect_stats, training_stats, should_stop_training
|
||||
|
||||
def _collect_training_data(self) -> CollectStats:
|
||||
"""Performs training data collection
|
||||
"""Performs training data collection.
|
||||
|
||||
:return: the data collection stats
|
||||
"""
|
||||
@ -446,7 +443,10 @@ class BaseTrainer(ABC):
|
||||
assert self.train_collector is not None
|
||||
if self.train_fn:
|
||||
self.train_fn(self.epoch, self.env_step)
|
||||
collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
|
||||
collect_stats = self.train_collector.collect(
|
||||
n_step=self.step_per_collect,
|
||||
n_episode=self.episode_per_collect,
|
||||
)
|
||||
|
||||
self.env_step += collect_stats.n_collected_steps
|
||||
|
||||
@ -465,8 +465,7 @@ class BaseTrainer(ABC):
|
||||
return collect_stats
|
||||
|
||||
def _test_in_train(self, collect_stats: CollectStats) -> bool:
|
||||
"""
|
||||
If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
|
||||
"""If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
|
||||
Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
|
||||
on it.
|
||||
Finally, if the latter is also True, will set should_stop_training to True.
|
||||
|
Loading…
x
Reference in New Issue
Block a user