From e499bed8b031592c4da5c4f897ab1ffeec74c03a Mon Sep 17 00:00:00 2001
From: Maximilian Huettenrauch
Date: Wed, 24 Apr 2024 17:06:42 +0200
Subject: [PATCH] add is_eval attribute to policy and set this attribute as
 well as train mode in appropriate places

---
 tianshou/data/collector.py                | 24 +++++++++++++++++++++++
 tianshou/highlevel/experiment.py          |  3 +--
 tianshou/policy/base.py                   |  2 ++
 tianshou/policy/modelfree/a2c.py          |  3 +++
 tianshou/policy/modelfree/bdq.py          |  3 +++
 tianshou/policy/modelfree/c51.py          |  2 ++
 tianshou/policy/modelfree/discrete_sac.py |  8 ++++----
 tianshou/policy/modelfree/dqn.py          |  2 ++
 tianshou/policy/modelfree/fqf.py          |  2 ++
 tianshou/policy/modelfree/iqn.py          |  2 ++
 tianshou/policy/modelfree/pg.py           |  8 ++++----
 tianshou/policy/modelfree/ppo.py          |  3 +++
 tianshou/policy/modelfree/qrdqn.py        |  2 ++
 tianshou/policy/modelfree/redq.py         |  2 +-
 tianshou/policy/modelfree/sac.py          |  5 ++++-
 tianshou/trainer/base.py                  |  8 --------
 tianshou/trainer/utils.py                 |  5 +----
 17 files changed, 60 insertions(+), 24 deletions(-)

diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 345d50b..1a60bd0 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -318,6 +318,7 @@ class Collector:
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes.
 
@@ -334,6 +335,7 @@
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
+        :param is_eval: whether to collect data in evaluation mode.
 
         .. note::
 
@@ -356,6 +358,13 @@
         # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
         # Only used in n_episode case. Then, R becomes R-S.
 
+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+        # policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}
 
@@ -568,6 +577,9 @@ class Collector:
             # reset envs and the _pre_collect fields
             self.reset_env(gym_reset_kwargs)  # todo still necessary?
 
+        # set the policy back to pre collect mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
@@ -665,6 +677,7 @@ class AsyncCollector(Collector):
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
+        is_eval: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes with async env setting.
 
@@ -686,6 +699,7 @@ class AsyncCollector(Collector):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Defaults to None (extra keyword arguments)
+        :param is_eval: whether to collect data in evaluation mode.
 
         .. note::
 
@@ -694,6 +708,13 @@ class AsyncCollector(Collector):
         :return: A dataclass object
         """
 
+        # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy
+        # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on
+        # policy.deterministic_eval)
+        self.policy.eval()
+        pre_collect_is_eval = self.policy.is_eval
+        self.policy.is_eval = is_eval
+
         use_grad = not no_grad
         gym_reset_kwargs = gym_reset_kwargs or {}
 
@@ -902,6 +923,9 @@ class AsyncCollector(Collector):
         # persist for future collect iterations
         self._ready_env_ids_R = ready_env_ids_R
 
+        # set the policy back to pre collect mode
+        self.policy.is_eval = pre_collect_is_eval
+
         return CollectStats.with_autogenerated_stats(
             returns=np.array(episode_returns),
             lens=np.array(episode_lens),
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 6f9eb7c..71b8159 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -335,10 +335,9 @@ class Experiment(ToStringMixin):
         env: BaseVectorEnv,
         render: float,
     ) -> None:
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        result = collector.collect(n_episode=num_episodes, render=render)
+        result = collector.collect(n_episode=num_episodes, render=render, is_eval=True)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
         log.info(
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 77602a0..450bf52 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -225,6 +225,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
+        # whether the policy is in evaluation mode
+        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
         self._compile()
 
     @property
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index d41ccb4..f055c41 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -165,6 +165,9 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py
index d7196a9..80c17be 100644
--- a/tianshou/policy/modelfree/bdq.py
+++ b/tianshou/policy/modelfree/bdq.py
@@ -163,6 +163,9 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
+        # set policy in train mode
+        self.train()
+
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py
index 5bfdba0..d406dda 100644
--- a/tianshou/policy/modelfree/c51.py
+++ b/tianshou/policy/modelfree/c51.py
@@ -117,6 +117,8 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index e9f9b3b..d236e44 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -107,10 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
 
     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
@@ -127,6 +124,9 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
+
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index e0ada07..d2e7910 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -210,6 +210,8 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py
index 9c87f9c..c4a1a2d 100644
--- a/tianshou/policy/modelfree/fqf.py
+++ b/tianshou/policy/modelfree/fqf.py
@@ -153,6 +153,8 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
index 75d76a2..f868ce4 100644
--- a/tianshou/policy/modelfree/iqn.py
+++ b/tianshou/policy/modelfree/iqn.py
@@ -131,6 +131,8 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 9a148fe..b86540d 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -197,10 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)
 
-        if self.deterministic_eval and not self.training:
-            act_B = dist.mode
-        else:
-            act_B = dist.sample()
+        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
@@ -214,6 +211,9 @@
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 196cd72..2987114 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -151,6 +151,9 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
+        # set policy in train mode
+        self.train()
+
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):
diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py
index 71c36de..9f3e162 100644
--- a/tianshou/policy/modelfree/qrdqn.py
+++ b/tianshou/policy/modelfree/qrdqn.py
@@ -105,6 +105,8 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
+        # set policy in train mode
+        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
index f9793f4..a216cf9 100644
--- a/tianshou/policy/modelfree/redq.py
+++ b/tianshou/policy/modelfree/redq.py
@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 3b39754..f9ddf7c 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and not self.training:
+        if self.deterministic_eval and self.is_eval:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
@@ -213,6 +213,9 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
+        # set policy in train mode
+        self.train()
+
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 675112f..fa67d3b 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -269,7 +269,6 @@ class BaseTrainer(ABC):
             assert self.episode_per_test is not None
             assert not isinstance(self.test_collector, AsyncCollector)  # Issue 700
             test_result = test_episode(
-                self.policy,
                 self.test_collector,
                 self.test_fn,
                 self.start_epoch,
@@ -309,9 +308,6 @@
         if self.stop_fn_flag:
             raise StopIteration
 
-        # set policy in train mode
-        self.policy.train()
-
         progress = tqdm.tqdm if self.show_progress else DummyTqdm
 
         # perform n step_per_epoch
@@ -395,7 +391,6 @@
         assert self.test_collector is not None
         stop_fn_flag = False
         test_stat = test_episode(
-            self.policy,
            self.test_collector,
            self.test_fn,
            self.epoch,
@@ -468,7 +463,6 @@
         ):
             assert self.test_collector is not None
             test_result = test_episode(
-                self.policy,
                 self.test_collector,
                 self.test_fn,
                 self.epoch,
@@ -481,8 +475,6 @@
                 should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
-            else:
-                self.policy.train()
         return result, should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 7a96ea0..4d990db 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -11,12 +11,10 @@ from tianshou.data import (
     SequenceSummaryStats,
     TimingStats,
 )
-from tianshou.policy import BasePolicy
 from tianshou.utils import BaseLogger
 
 
 def test_episode(
-    policy: BasePolicy,
     collector: Collector,
     test_fn: Callable[[int, int | None], None] | None,
     epoch: int,
@@ -27,10 +25,9 @@
 ) -> CollectStats:
     """A simple wrapper of testing policy in collector."""
     collector.reset(reset_stats=False)
-    policy.eval()
     if test_fn:
         test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode)
+    result = collector.collect(n_episode=n_episode, is_eval=True)
     if reward_metric:  # TODO: move into collector
         rew = reward_metric(result.returns)
         result.returns = rew
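
For context, here is a minimal usage sketch of the collection API after this patch. It is not part of the commit: `policy` stands for any already-constructed tianshou policy (e.g. a DQNPolicy), and the environment/buffer setup is purely illustrative. The point is that callers no longer toggle `policy.eval()` around evaluation rollouts; they pass `is_eval=True` to `Collector.collect`, which switches the policy's modules to eval mode, sets `policy.is_eval` for the duration of the call, and restores the previous value afterwards.

# Illustrative sketch (assumes tianshou at this commit; `policy` is a placeholder
# for any constructed BasePolicy subclass -- its construction is omitted here).
import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv

envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
collector = Collector(policy, envs, VectorReplayBuffer(20_000, len(envs)))
collector.reset()

# Training-time rollout: policy.is_eval stays False, so actions are sampled
# from the policy's action distribution (exploration).
train_stats = collector.collect(n_step=1_000)

# Evaluation rollout: the collector calls policy.eval(), sets policy.is_eval = True
# for the duration of the call (deterministic actions when deterministic_eval is set),
# and restores the previous is_eval value before returning.
eval_stats = collector.collect(n_episode=10, is_eval=True)

Note that train mode for the network modules is restored lazily: each `learn()` implementation now calls `self.train()` at the start of an update, so a preceding evaluation collect cannot leave dropout or batchnorm in eval mode during gradient steps.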