add is_eval attribute to policy and set this attribute as well as train mode in appropriate places

Maximilian Huettenrauch 2024-04-24 17:06:42 +02:00
parent ade85ab32b
commit e499bed8b0
17 changed files with 60 additions and 24 deletions

View File

@@ -318,6 +318,7 @@ class Collector:
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes.
@@ -334,6 +335,7 @@ class Collector:
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
+        :param is_eval: whether to collect data in evaluation mode.

         .. note::
@@ -356,6 +358,13 @@ class Collector:
         # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
         # Only used in n_episode case. Then, R becomes R-S.

+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy's
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or to use the mode, depending
+        # on policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}
@@ -568,6 +577,9 @@ class Collector:
             # reset envs and the _pre_collect fields
             self.reset_env(gym_reset_kwargs)  # todo still necessary?

+        # set the policy back to its pre-collect evaluation mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -665,6 +677,7 @@ class AsyncCollector(Collector):
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes with async env setting.
@@ -686,6 +699,7 @@ class AsyncCollector(Collector):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Defaults to None (extra keyword arguments)
+        :param is_eval: whether to collect data in evaluation mode.

         .. note::
@@ -694,6 +708,13 @@ class AsyncCollector(Collector):
         :return: A dataclass object
         """
+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy's
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or to use the mode, depending
+        # on policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}
@@ -902,6 +923,9 @@ class AsyncCollector(Collector):
         # persist for future collect iterations
         self._ready_env_ids_R = ready_env_ids_R

+        # set the policy back to its pre-collect evaluation mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
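
To make the intended call pattern concrete, here is a minimal usage sketch of the new argument. The policy construction is omitted and assumed to happen elsewhere (any BasePolicy subclass works; deterministic evaluation only applies to policies built with deterministic_eval=True), and the environment and buffer choices are illustrative, not part of this commit.

import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

# policy: an already-constructed BasePolicy subclass (e.g. a PGPolicy created with
# deterministic_eval=True); its construction is omitted from this sketch.
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
collector = Collector(policy, envs, VectorReplayBuffer(20_000, len(envs)))
collector.reset()

# training-time collection: policy.is_eval stays False, so the policy samples actions
train_stats = collector.collect(n_step=2048)

# evaluation: collect() calls policy.eval(), sets policy.is_eval = True for the duration
# of the call, and restores the previous is_eval value afterwards; a policy with
# deterministic_eval=True then returns the distribution's mode instead of sampling
eval_stats = collector.collect(n_episode=10, is_eval=True)
print(eval_stats.returns_stat)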

View File

@@ -335,10 +335,9 @@ class Experiment(ToStringMixin):
         env: BaseVectorEnv,
         render: float,
     ) -> None:
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render)
+        result = collector.collect(n_episode=num_episodes, render=render, is_eval=True)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(

View File

@@ -225,6 +225,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
+        # whether the policy is in evaluation mode
+        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
         self._compile()

     @property
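
Since the diff shows only the attribute itself, the following sketch summarizes the resulting state model; it is an interpretation of this commit rather than code taken from it. The two flags are independent: nn.Module.train()/eval() keeps controlling layer behavior, while is_eval controls behavioral evaluation.

from tianshou.policy import BasePolicy


def describe_policy_mode(policy: BasePolicy) -> str:
    # policy.training (inherited from torch.nn.Module) governs layers such as dropout
    # and batchnorm; the collector calls policy.eval() before collecting and each
    # learn() implementation calls self.train() before updating.
    module_mode = "train" if policy.training else "eval"
    # policy.is_eval (added here) governs action selection: policies constructed with
    # deterministic_eval=True return the distribution's mode only while is_eval is True.
    # deterministic_eval lives on subclasses such as PGPolicy, hence the getattr.
    deterministic = getattr(policy, "deterministic_eval", False) and policy.is_eval
    action_mode = "deterministic (dist.mode)" if deterministic else "stochastic (dist.sample())"
    return f"modules: {module_mode}, action selection: {action_mode}"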

View File

@@ -165,6 +165,9 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
+        # set policy in train mode
+        self.train()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
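
The same two added lines recur in every learn() method touched by this commit. The snippet below is plain PyTorch rather than Tianshou code and only illustrates why the switch matters: after a collect() call the modules are left in eval mode, so an update without self.train() would run with dropout disabled and frozen batch-norm statistics.

import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Dropout(p=0.5), torch.nn.Linear(8, 2))
x = torch.randn(1, 4)

net.eval()        # the state the collector leaves the modules in after collecting
y_eval = net(x)   # dropout inactive: repeated calls give identical outputs

net.train()       # what learn() now restores before computing gradients
y_train = net(x)  # dropout active again, as intended during optimization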

View File

@@ -163,6 +163,9 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -117,6 +117,8 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -107,10 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
@@ -127,6 +124,9 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)

View File

@@ -210,6 +210,8 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -153,6 +153,8 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)

View File

@@ -131,6 +131,8 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -197,10 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
@@ -214,6 +211,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
+        # set policy in train mode
+        self.train()
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
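
The condition for deterministic action selection changes from "not self.training" to "self.is_eval", so greedy evaluation is now requested explicitly instead of being inferred from the nn.Module flag. A small standalone illustration of the new one-liner, using plain torch.distributions and hypothetical flag values:

import torch
from torch.distributions import Independent, Normal

deterministic_eval, is_eval = True, True  # what a collect(..., is_eval=True) call implies
dist = Independent(Normal(loc=torch.zeros(2), scale=torch.ones(2)), 1)

# same selection rule as in the diff above: mode during evaluation, sample otherwise
act = dist.mode if deterministic_eval and is_eval else dist.sample()
print(act)  # tensor([0., 0.]), the mode (= mean) of the Gaussian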

View File

@@ -151,6 +151,9 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
+        # set policy in train mode
+        self.train()
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):

View File

@@ -105,6 +105,8 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()

View File

@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -213,6 +213,9 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)

View File

@@ -269,7 +269,6 @@ class BaseTrainer(ABC):
            assert self.episode_per_test is not None
            assert not isinstance(self.test_collector, AsyncCollector)  # Issue 700
            test_result = test_episode(
-                self.policy,
                self.test_collector,
                self.test_fn,
                self.start_epoch,
@@ -309,9 +308,6 @@ class BaseTrainer(ABC):
        if self.stop_fn_flag:
            raise StopIteration

-        # set policy in train mode
-        self.policy.train()
-
        progress = tqdm.tqdm if self.show_progress else DummyTqdm

        # perform n step_per_epoch
@@ -395,7 +391,6 @@ class BaseTrainer(ABC):
        assert self.test_collector is not None
        stop_fn_flag = False
        test_stat = test_episode(
-            self.policy,
            self.test_collector,
            self.test_fn,
            self.epoch,
@@ -468,7 +463,6 @@ class BaseTrainer(ABC):
        ):
            assert self.test_collector is not None
            test_result = test_episode(
-                self.policy,
                self.test_collector,
                self.test_fn,
                self.epoch,
@@ -481,8 +475,6 @@ class BaseTrainer(ABC):
                should_stop_training = True
                self.best_reward = test_result.returns_stat.mean
                self.best_reward_std = test_result.returns_stat.std
-            else:
-                self.policy.train()
        return result, should_stop_training

    # TODO: move moving average computation and logging into its own logger
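
Taken together with the trainer edits above, the per-epoch division of responsibility now looks roughly like the sketch below. It is a schematic helper written for illustration, with assumed argument names; it is not BaseTrainer code.

from tianshou.data import Collector
from tianshou.policy import BasePolicy


def one_epoch_sketch(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    step_per_collect: int,
    sample_size: int,
    episode_per_test: int,
) -> None:
    # collection: collect() itself calls policy.eval() and sets policy.is_eval
    train_collector.collect(n_step=step_per_collect)  # is_eval defaults to False
    # update: the policy's learn() calls self.train() before computing gradients
    policy.update(sample_size, train_collector.buffer)
    # evaluation: requested explicitly, mirroring what test_episode now does
    test_collector.collect(n_episode=episode_per_test, is_eval=True)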

View File

@@ -11,12 +11,10 @@ from tianshou.data import (
    SequenceSummaryStats,
    TimingStats,
)
-from tianshou.policy import BasePolicy
from tianshou.utils import BaseLogger


def test_episode(
-    policy: BasePolicy,
    collector: Collector,
    test_fn: Callable[[int, int | None], None] | None,
    epoch: int,
@@ -27,10 +25,9 @@ def test_episode(
) -> CollectStats:
    """A simple wrapper of testing policy in collector."""
    collector.reset(reset_stats=False)
-    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode)
+    result = collector.collect(n_episode=n_episode, is_eval=True)
    if reward_metric:  # TODO: move into collector
        rew = reward_metric(result.returns)
        result.returns = rew
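
Callers of this helper no longer pass the policy or flip its mode themselves; the remaining policy-specific hook is test_fn. Below is a hypothetical test_fn factory for an epsilon-greedy policy (DQNPolicy.set_eps is an existing method; the wiring is illustrative, not part of this commit):

from tianshou.policy import DQNPolicy


def make_test_fn(policy: DQNPolicy):
    # test_episode calls test_fn(epoch, global_step) right before collecting test
    # episodes; switching into evaluation mode itself is now handled by
    # collector.collect(n_episode=..., is_eval=True).
    def test_fn(epoch: int, global_step: int | None) -> None:
        policy.set_eps(0.05)  # small exploration epsilon during testing
    return test_fn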