add is_eval attribute to policy and set this attribute, as well as train mode, in the appropriate places
parent ade85ab32b
commit e499bed8b0
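In short, the commit makes two behaviours explicit: `learn()` always runs with the policy's modules in train mode, and `Collector.collect()` temporarily switches the new `policy.is_eval` flag (restoring it afterwards) so that `deterministic_eval` only takes effect during evaluation rollouts. Below is a minimal, self-contained sketch of those semantics; it is toy code, not Tianshou itself, and the names `ToyPolicy`/`collect` are illustrative only.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyPolicy:
    deterministic_eval: bool = True
    is_eval: bool = False  # mirrors the new BasePolicy attribute introduced below

    def act(self, obs: float) -> float:
        # sample during training rollouts, use the distribution's "mode" during evaluation
        if self.deterministic_eval and self.is_eval:
            return obs
        return obs + random.gauss(0.0, 0.1)


def collect(policy: ToyPolicy, n_step: int, is_eval: bool = False) -> list[float]:
    pre_collect_is_eval = policy.is_eval  # save the caller's flag, as Collector.collect does
    policy.is_eval = is_eval
    try:
        return [policy.act(float(t)) for t in range(n_step)]
    finally:
        policy.is_eval = pre_collect_is_eval  # restore it before returning


if __name__ == "__main__":
    policy = ToyPolicy()
    print(collect(policy, n_step=3))                # stochastic (training) actions
    print(collect(policy, n_step=3, is_eval=True))  # deterministic (evaluation) actions
    assert policy.is_eval is False                  # flag restored after each call
```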
@@ -318,6 +318,7 @@ class Collector:
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes.

@@ -334,6 +335,7 @@ class Collector:
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
+        :param is_eval: whether to collect data in evaluation mode.

         .. note::

@@ -356,6 +358,13 @@ class Collector:
         # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
         # Only used in n_episode case. Then, R becomes R-S.

+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+        # policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}

@@ -568,6 +577,9 @@ class Collector:
             # reset envs and the _pre_collect fields
             self.reset_env(gym_reset_kwargs)  # todo still necessary?

+        # set the policy back to pre collect mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
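The save/restore of `is_eval` around the body of `collect()` could equally be written as a context manager; the hunks above inline it instead. A hypothetical sketch for comparison (not part of this commit; the helper name `policy_eval_flag` is made up):

```python
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def policy_eval_flag(policy, is_eval: bool) -> Iterator[None]:
    """Temporarily set policy.is_eval, restoring the previous value on exit."""
    previous = policy.is_eval
    policy.is_eval = is_eval
    try:
        yield
    finally:
        policy.is_eval = previous
```

A `try`/`finally` (or context manager) would also restore the flag if collection raises, which the inlined assignments above do not appear to attempt.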
@@ -665,6 +677,7 @@ class AsyncCollector(Collector):
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes with async env setting.

@@ -686,6 +699,7 @@ class AsyncCollector(Collector):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Defaults to None (extra keyword arguments)
+        :param is_eval: whether to collect data in evaluation mode.

         .. note::

@@ -694,6 +708,13 @@ class AsyncCollector(Collector):

         :return: A dataclass object
         """
+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+        # policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}

@@ -902,6 +923,9 @@ class AsyncCollector(Collector):
         # persist for future collect iterations
         self._ready_env_ids_R = ready_env_ids_R

+        # set the policy back to pre collect mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -335,10 +335,9 @@ class Experiment(ToStringMixin):
         env: BaseVectorEnv,
         render: float,
     ) -> None:
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render)
+        result = collector.collect(n_episode=num_episodes, render=render, is_eval=True)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
@@ -225,6 +225,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
+        # whether the policy is in evaluation mode
+        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
         self._compile()

     @property
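The new flag only changes behaviour where a policy's `forward()` consults it together with `deterministic_eval`, as in the one-line replacements in the policy hunks below. A minimal sketch of that interaction (plain PyTorch, not Tianshou code; assumes a PyTorch version where `Categorical.mode` is available, which the diff itself already relies on):

```python
import torch
from torch.distributions import Categorical


def select_action(logits: torch.Tensor, deterministic_eval: bool, is_eval: bool) -> torch.Tensor:
    dist = Categorical(logits=logits)
    # same pattern as `dist.mode if self.deterministic_eval and self.is_eval else dist.sample()`
    return dist.mode if deterministic_eval and is_eval else dist.sample()


logits = torch.tensor([[0.1, 2.0, 0.3]])
assert select_action(logits, deterministic_eval=True, is_eval=True).item() == 1  # argmax
print(select_action(logits, deterministic_eval=True, is_eval=False))             # sampled
```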
@@ -165,6 +165,9 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
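This `self.train()` preamble, repeated in each `learn()` implementation below, guarantees that gradient updates always run with train-mode modules such as dropout and batch norm, regardless of what mode the last collection or evaluation step left the policy in. A self-contained illustration of the pattern (toy module, not Tianshou code):

```python
import torch
from torch import nn


class ToyLearner(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 16), nn.Dropout(0.5), nn.Linear(16, 2))

    def learn(self, batch_obs: torch.Tensor) -> torch.Tensor:
        # set policy in train mode, as in the hunks above and below
        self.train()
        return self.net(batch_obs).mean()


learner = ToyLearner()
learner.eval()                        # e.g. left over from an earlier evaluation step
loss = learner.learn(torch.randn(8, 4))
assert learner.training               # learn() switched the module back to train mode
```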
@@ -163,6 +163,9 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
+        # set policy in train mode
+        self.train()
+
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -117,6 +117,8 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -107,10 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
@@ -127,6 +124,9 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
+
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
@@ -210,6 +210,8 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -153,6 +153,8 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
@@ -131,6 +131,8 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -197,10 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)

-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
@@ -214,6 +211,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
@@ -151,6 +151,9 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):
@@ -105,6 +105,8 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -213,6 +213,9 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
+
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
@@ -269,7 +269,6 @@ class BaseTrainer(ABC):
             assert self.episode_per_test is not None
             assert not isinstance(self.test_collector, AsyncCollector)  # Issue 700
             test_result = test_episode(
-                self.policy,
                 self.test_collector,
                 self.test_fn,
                 self.start_epoch,
@@ -309,9 +308,6 @@ class BaseTrainer(ABC):
         if self.stop_fn_flag:
             raise StopIteration

-        # set policy in train mode
-        self.policy.train()
-
         progress = tqdm.tqdm if self.show_progress else DummyTqdm

         # perform n step_per_epoch
@@ -395,7 +391,6 @@ class BaseTrainer(ABC):
         assert self.test_collector is not None
         stop_fn_flag = False
         test_stat = test_episode(
-            self.policy,
             self.test_collector,
             self.test_fn,
             self.epoch,
@@ -468,7 +463,6 @@ class BaseTrainer(ABC):
         ):
             assert self.test_collector is not None
             test_result = test_episode(
-                self.policy,
                 self.test_collector,
                 self.test_fn,
                 self.epoch,
@@ -481,8 +475,6 @@ class BaseTrainer(ABC):
                 should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
-            else:
-                self.policy.train()
         return result, should_stop_training

     # TODO: move moving average computation and logging into its own logger
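With these trainer hunks, mode handling moves out of the training loop: the trainer no longer calls `policy.train()`/`policy.eval()` itself, relying instead on `learn()` setting train mode and on the collector's `is_eval` flag for evaluation. A rough sketch of the resulting division of responsibility (illustrative fragment; the `run_epoch` helper and its arguments are made up, not Tianshou API):

```python
def run_epoch(policy, train_collector, test_collector, n_step, n_test_episode, sample_size):
    # training rollouts: stochastic actions, is_eval left at its default (False)
    train_collector.collect(n_step=n_step)
    # gradient updates: learn() switches the policy's modules to train mode itself
    policy.update(sample_size, train_collector.buffer)
    # evaluation rollouts: is_eval=True for the duration of the call, then restored
    return test_collector.collect(n_episode=n_test_episode, is_eval=True)
```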
@@ -11,12 +11,10 @@ from tianshou.data import (
     SequenceSummaryStats,
     TimingStats,
 )
-from tianshou.policy import BasePolicy
 from tianshou.utils import BaseLogger


 def test_episode(
-    policy: BasePolicy,
     collector: Collector,
     test_fn: Callable[[int, int | None], None] | None,
     epoch: int,
@@ -27,10 +25,9 @@ def test_episode(
 ) -> CollectStats:
     """A simple wrapper of testing policy in collector."""
     collector.reset(reset_stats=False)
-    policy.eval()
     if test_fn:
         test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode)
+    result = collector.collect(n_episode=n_episode, is_eval=True)
     if reward_metric:  # TODO: move into collector
         rew = reward_metric(result.returns)
         result.returns = rew