Set torch train mode in BasePolicy.update instead of in each .learn implementation,
as this is less prone to errors
This commit is contained in:
parent a2b9d7c7d8
commit 4f16494609
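For context on the "less prone to errors" point: when each .learn() calls self.train() itself and nothing restores the previous mode, modules with mode-dependent behaviour (dropout, batch norm) can silently stay in train mode during later inference. A small, generic PyTorch illustration of that failure mode (not taken from the Tianshou codebase):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()   # what each learn() used to do, with nothing undoing it
y1 = net(x)   # dropout is active here

# Unless the caller remembers to switch back, later "inference" calls
# still run with dropout enabled and give stochastic outputs:
y2 = net(x)

net.eval()    # must be restored manually in the old scheme
y3 = net(x)   # deterministic, dropout disabled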
@@ -25,6 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_train_mode

 logger = logging.getLogger(__name__)

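The imported in_train_mode helper is not part of this diff; assuming it simply toggles and restores nn.Module.training, a minimal sketch of such a context manager could look like the following (the actual implementation in tianshou.utils.torch_utils may differ):

from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily put a module into train mode, restoring the previous mode on exit."""
    previous_mode = module.training
    module.train()
    try:
        yield
    finally:
        module.train(previous_mode)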
@@ -513,7 +514,8 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        training_stat = self.learn(batch, **kwargs)
+        with in_train_mode(self):
+            training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
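With the change above, update() guarantees that learn() runs in train mode and that the previous mode is restored afterwards, regardless of what the individual policy does. A hypothetical usage sketch, where policy and buffer are placeholders rather than objects from this diff:

policy.eval()               # mode used while collecting or evaluating
policy.update(64, buffer)   # learn() executes inside in_train_mode(policy)
assert not policy.training  # the eval mode from before update() is restored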
@@ -165,9 +165,6 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
@@ -163,9 +163,6 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
-        # set policy in train mode
-        self.train()
-
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -117,8 +117,6 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -124,9 +124,6 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
@@ -210,8 +210,6 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -153,8 +153,6 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
@@ -131,8 +131,6 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -211,9 +211,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
@@ -151,9 +151,6 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):
@@ -105,8 +105,6 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -213,9 +213,6 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )

     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)