Formatting

commit f876198870 (parent 6927eadaa7)
@@ -27,7 +27,7 @@ from tianshou.data.types import (
 from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.utils.print import DataclassPPrintMixin
-from tianshou.utils.torch_utils import in_eval_mode, in_train_mode
+from tianshou.utils.torch_utils import in_eval_mode
 
 log = logging.getLogger(__name__)
 
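For context, an eval-mode helper of this kind is typically a context manager that temporarily puts a module into eval() and restores its previous training flag on exit. The sketch below is illustrative only and is not the actual tianshou.utils.torch_utils implementation:

from collections.abc import Iterator
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def eval_mode(module: nn.Module) -> Iterator[None]:
    # Hypothetical sketch: remember the current mode, switch to eval, restore on exit.
    was_training = module.training
    module.eval()
    try:
        yield
    finally:
        module.train(was_training)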
@@ -260,8 +260,16 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass
 
-    def collect(self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None,
-                no_grad: bool = True, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None) -> CollectStats:
+    def collect(
+        self,
+        n_step: int | None = None,
+        n_episode: int | None = None,
+        random: bool = False,
+        render: float | None = None,
+        no_grad: bool = True,
+        reset_before_collect: bool = False,
+        gym_reset_kwargs: dict[str, Any] | None = None,
+    ) -> CollectStats:
         """Collect a specified number of steps or episodes.
 
         To ensure an unbiased sampling result with the n_episode option, this function will
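A call against the reformatted signature might look like the sketch below; `collector` is assumed to be an already constructed collector instance, and the argument values are placeholders:

# Hypothetical usage of the keyword-only style shown above.
stats = collector.collect(
    n_step=1024,              # collect a fixed number of environment steps ...
    random=False,             # ... using the policy rather than random actions
    reset_before_collect=True,
)
print(stats.n_collected_steps)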
@@ -130,7 +130,10 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
+            train_collector.collect(
+                n_step=self.sampling_config.start_timesteps,
+                random=self.sampling_config.start_timesteps_random,
+            )
         return train_collector, test_collector
 
     def set_policy_wrapper_factory(
@@ -107,7 +107,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
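The reformatted conditional picks the distribution mode during deterministic evaluation and samples otherwise. A minimal standalone illustration of that pattern (outside any policy class; the two boolean flags stand in for the policy attributes) could be:

import torch
from torch.distributions import Categorical

logits = torch.tensor([[0.1, 2.0, -1.0]])
dist = Categorical(logits=logits)

deterministic_eval = True
is_within_training_step = False  # assumed flag mirroring the policy attribute

act = (
    dist.mode       # greedy action (argmax of the probabilities)
    if deterministic_eval and not is_within_training_step
    else dist.sample()  # stochastic action during training/exploration
)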
@@ -197,7 +197,11 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)
 
-        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
+        act_B = (
+            dist.mode
+            if self.deterministic_eval and not self.is_within_training_step
+            else dist.sample()
+        )
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
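The same mode-versus-sample pattern covers the continuous case mentioned in the comment: for an Independent Normal distribution, .mode returns the mean while .sample() draws a stochastic action. A sketch, assuming a torch version that exposes Distribution.mode (>= 2.0):

import torch
from torch.distributions import Independent, Normal

loc_B = torch.zeros(4, 2)     # batch of 4 actions, 2 dims each
scale_B = torch.ones(4, 2)
dist = Independent(Normal(loc_B, scale_B), 1)

deterministic = True
act_B = dist.mode if deterministic else dist.sample()  # shape (4, 2) either way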
@@ -5,7 +5,6 @@ from collections import defaultdict, deque
 from collections.abc import Callable
 from contextlib import contextmanager
 from dataclasses import asdict
-from typing import Optional, Tuple
 
 import numpy as np
 import tqdm
@@ -305,7 +304,6 @@ class BaseTrainer(ABC):
         with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
             train_stat: CollectStatsBase
             while t.n < t.total and not self.stop_fn_flag:
-
                 train_stat, update_stat, self.stop_fn_flag = self.training_step()
 
                 if isinstance(train_stat, CollectStats):
@@ -417,9 +415,8 @@ class BaseTrainer(ABC):
         finally:
             self.policy.is_within_training_step = old_value
 
-    def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
+    def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]:
         with self._is_within_training_step_enabled(True):
-
             should_stop_training = False
 
             if self.train_collector is not None:
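The annotation change above replaces typing.Tuple/Optional with the builtin generic and union syntax (PEP 585/604), which is why the typing import could be dropped in the earlier hunk. The two spellings are equivalent, as in this illustrative comparison:

from typing import Optional, Tuple

# Old style, requiring the typing imports removed in this commit:
OldReturn = Tuple[int, Optional[str], bool]

# Modern equivalent, usable on Python >= 3.10 without extra imports:
NewReturn = tuple[int, str | None, bool]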
@@ -438,7 +435,7 @@ class BaseTrainer(ABC):
         return collect_stats, training_stats, should_stop_training
 
     def _collect_training_data(self) -> CollectStats:
-        """Performs training data collection
+        """Performs training data collection.
 
         :return: the data collection stats
         """
@@ -446,7 +443,10 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)
+        collect_stats = self.train_collector.collect(
+            n_step=self.step_per_collect,
+            n_episode=self.episode_per_collect,
+        )
 
         self.env_step += collect_stats.n_collected_steps
 
@@ -465,8 +465,7 @@ class BaseTrainer(ABC):
         return collect_stats
 
     def _test_in_train(self, collect_stats: CollectStats) -> bool:
-        """
-        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        """If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
         Then, if the stop_fn is True there, will collect test data also compute the stop_fn of the mean return
         on it.
         Finally, if the latter is also True, will set should_stop_training to True.
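As a rough sketch of the control flow this docstring describes (function and attribute names here are assumptions for illustration, not the actual implementation):

def test_in_train_sketch(train_stats, stop_fn, collect_test_data):
    # Hypothetical outline: stop only if both the training return and a fresh
    # test return satisfy the stopping criterion.
    should_stop_training = False
    if stop_fn(train_stats.returns_stat.mean):
        test_stats = collect_test_data()
        if stop_fn(test_stats.returns_stat.mean):
            should_stop_training = True
    return should_stop_training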