Set torch train mode in BasePolicy.update instead of in each .learn implementation, as this is less prone to errors
Author: Dominik Jain, 2024-05-02 11:51:08 +02:00
Parent: a2b9d7c7d8
Commit: 4f16494609
12 changed files with 3 additions and 29 deletions

View File

@@ -25,6 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_train_mode
 logger = logging.getLogger(__name__)
@@ -513,7 +514,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        training_stat = self.learn(batch, **kwargs)
+        with in_train_mode(self):
+            training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
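For context, the body of in_train_mode (imported above from tianshou.utils.torch_utils) is not part of this diff. The sketch below is an assumption about what such a context manager typically looks like, not a copy of the actual torch_utils implementation, which may differ in naming or details such as an explicit enabled flag:

from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch `module` to train mode, restoring its previous mode on exit."""
    previous_mode = module.training  # remember whether the module was in train or eval mode
    module.train()
    try:
        yield
    finally:
        # Restore the original mode even if learn() raises, so update() never
        # leaves the policy stuck in train mode.
        module.train(previous_mode)

Because the restore happens in a finally block, every learn() implementation runs in train mode without calling self.train() itself, and the policy's previous mode is guaranteed to be restored afterwards, which is what makes this less error-prone than the per-policy calls removed below.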

View File

@@ -165,9 +165,6 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
-        # set policy in train mode
-        self.train()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):

View File

@@ -163,9 +163,6 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -117,8 +117,6 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -124,9 +124,6 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)

View File

@@ -210,8 +210,6 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -153,8 +153,6 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)

View File

@@ -131,8 +131,6 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -211,9 +211,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
-        # set policy in train mode
-        self.train()
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):

View File

@@ -151,9 +151,6 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
-        # set policy in train mode
-        self.train()
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):

View File

@@ -105,8 +105,6 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()

View File

@@ -213,9 +213,6 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
         )
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
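As a usage illustration (a hypothetical snippet, not part of this commit; policy stands for a concrete off-policy policy such as DQNPolicy and buffer for a pre-filled ReplayBuffer), and assuming in_train_mode restores the previous mode as sketched above:

policy.eval()                    # e.g. while evaluating between training epochs
policy.update(64, buffer)        # learn() runs in train mode inside the context manager
assert policy.training is False  # the previous (eval) mode is restored afterwards

Previously, each learn() implementation called self.train() and left the policy in train mode, so callers had to remember to switch back to eval mode themselves.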