From 4f16494609a33de67bb6454aaf3e64e523d0de60 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 2 May 2024 11:51:08 +0200
Subject: [PATCH] Set torch train mode in BasePolicy.update instead of in each
 .learn implementation, as this is less prone to errors

---
 tianshou/policy/base.py                   | 4 +++-
 tianshou/policy/modelfree/a2c.py          | 3 ---
 tianshou/policy/modelfree/bdq.py          | 3 ---
 tianshou/policy/modelfree/c51.py          | 2 --
 tianshou/policy/modelfree/discrete_sac.py | 3 ---
 tianshou/policy/modelfree/dqn.py          | 2 --
 tianshou/policy/modelfree/fqf.py          | 2 --
 tianshou/policy/modelfree/iqn.py          | 2 --
 tianshou/policy/modelfree/pg.py           | 3 ---
 tianshou/policy/modelfree/ppo.py          | 3 ---
 tianshou/policy/modelfree/qrdqn.py        | 2 --
 tianshou/policy/modelfree/sac.py          | 3 ---
 12 files changed, 3 insertions(+), 29 deletions(-)

diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 450bf52..e566d0c 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -25,6 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_train_mode
 
 logger = logging.getLogger(__name__)
 
@@ -513,7 +514,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        training_stat = self.learn(batch, **kwargs)
+        with in_train_mode(self):
+            training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index f055c41..d41ccb4 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -165,9 +165,6 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py
index 80c17be..d7196a9 100644
--- a/tianshou/policy/modelfree/bdq.py
+++ b/tianshou/policy/modelfree/bdq.py
@@ -163,9 +163,6 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
-        # set policy in train mode
-        self.train()
-
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py
index d406dda..5bfdba0 100644
--- a/tianshou/policy/modelfree/c51.py
+++ b/tianshou/policy/modelfree/c51.py
@@ -117,8 +117,6 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index d236e44..7e731b1 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -124,9 +124,6 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index d2e7910..e0ada07 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -210,8 +210,6 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py
index c4a1a2d..9c87f9c 100644
--- a/tianshou/policy/modelfree/fqf.py
+++ b/tianshou/policy/modelfree/fqf.py
@@ -153,8 +153,6 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
index f868ce4..75d76a2 100644
--- a/tianshou/policy/modelfree/iqn.py
+++ b/tianshou/policy/modelfree/iqn.py
@@ -131,8 +131,6 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index b86540d..4792db8 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -211,9 +211,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 2987114..196cd72 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -151,9 +151,6 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):
diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py
index 9f3e162..71c36de 100644
--- a/tianshou/policy/modelfree/qrdqn.py
+++ b/tianshou/policy/modelfree/qrdqn.py
@@ -105,8 +105,6 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index f9ddf7c..3dbea75 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -213,9 +213,6 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
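
Note: the patch imports in_train_mode from tianshou.utils.torch_utils, whose definition is not part of this diff. A minimal sketch of such a context manager, given only as an assumption for illustration and not necessarily the actual Tianshou implementation, could look like this:

    # Hypothetical sketch of the in_train_mode helper assumed by the patch;
    # the real tianshou.utils.torch_utils may differ in name and details.
    from collections.abc import Iterator
    from contextlib import contextmanager

    import torch.nn as nn


    @contextmanager
    def in_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
        """Temporarily set the module to train (or eval) mode, restoring the previous mode on exit."""
        previous_mode = module.training
        try:
            module.train(enabled)
            yield
        finally:
            module.train(previous_mode)

With a helper of this shape, BasePolicy.update wraps only the learn call in train mode and restores the policy's previous mode afterwards (even if learn raises), so individual .learn implementations no longer need to call self.train() themselves.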