Set torch train mode in BasePolicy.update instead of in each .learn implementation,
as this is less prone to errors
commit 4f16494609 (parent a2b9d7c7d8)
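
The `in_train_mode` helper imported below is not shown in this diff. As a rough sketch (an assumption about its behavior, not the actual `tianshou.utils.torch_utils` code), such a context manager can be written like this:

from collections.abc import Iterator
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def in_train_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch a module to train mode, restoring the previous mode on exit."""
    previous_mode = module.training
    try:
        module.train()
        yield
    finally:
        # Restore the original train/eval flag even if the wrapped code raised.
        module.train(previous_mode)

Because the restore happens in a finally block, every `.learn` call leaves the policy in whatever mode it was in before `update` was called, which is presumably the error-proneness the commit message refers to.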
@@ -25,6 +25,7 @@ from tianshou.data.types import (
 )
 from tianshou.utils import MultipleLRSchedulers
 from tianshou.utils.print import DataclassPPrintMixin
+from tianshou.utils.torch_utils import in_train_mode
 
 logger = logging.getLogger(__name__)
 
@@ -513,6 +514,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         batch, indices = buffer.sample(sample_size)
         self.updating = True
         batch = self.process_fn(batch, buffer, indices)
-        training_stat = self.learn(batch, **kwargs)
+        with in_train_mode(self):
+            training_stat = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
         if self.lr_scheduler is not None:
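With the change above, callers of `update` no longer depend on each algorithm's `.learn` to manage the torch train/eval flag. A small usage sketch (hypothetical `policy` and `buffer` objects, using the `update(sample_size, buffer)` signature visible in this hunk) illustrates the resulting behavior:

# Hypothetical: `policy` is any concrete BasePolicy subclass, `buffer` a filled ReplayBuffer.
policy.eval()                                  # mode used outside of gradient updates
assert policy.training is False

policy.update(sample_size=64, buffer=buffer)   # learn() now runs with training=True internally

assert policy.training is False                # the previous (eval) mode is restored afterwards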
@@ -165,9 +165,6 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]):  # typ
         *args: Any,
         **kwargs: Any,
     ) -> TA2CTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
@@ -163,9 +163,6 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
-        # set policy in train mode
-        self.train()
-
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -117,8 +117,6 @@ class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
         return target_dist.sum(-1)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -124,9 +124,6 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         return target_q.sum(dim=-1) + self.alpha * dist.entropy()
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         weight = batch.pop("weight", 1.0)
         target_q = batch.returns.flatten()
         act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
@@ -210,8 +210,6 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -153,8 +153,6 @@ class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
         return cast(FQFBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
@@ -131,8 +131,6 @@ class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
         return cast(QuantileRegressionBatchProtocol, result)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -211,9 +211,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         *args: Any,
         **kwargs: Any,
     ) -> TPGTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses = []
         split_batch_size = batch_size or -1
         for _ in range(repeat):
@@ -151,9 +151,6 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # ty
         *args: Any,
         **kwargs: Any,
     ) -> TPPOTrainingStats:
-        # set policy in train mode
-        self.train()
-
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
         split_batch_size = batch_size or -1
         for step in range(repeat):
@@ -105,8 +105,6 @@ class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
         return super().compute_q_value(logits.mean(2), mask)
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
-        # set policy in train mode
-        self.train()
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -213,9 +213,6 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         )
 
     def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats:  # type: ignore
-        # set policy in train mode
-        self.train()
-
         # critic 1&2
         td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
         td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim)
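All of the per-algorithm hunks above apply the same mechanical change. As an illustration (a generic sketch, not code from any specific file in this diff), the before/after shape of a `.learn` implementation is:

# Before: each learn() had to set the mode itself, and nothing ever switched it back.
def learn(self, batch, *args, **kwargs):
    self.train()
    ...  # compute losses, step the optimizer

# After: learn() only performs the gradient update;
# BasePolicy.update() wraps the call in `with in_train_mode(self):`.
def learn(self, batch, *args, **kwargs):
    ...  # compute losses, step the optimizer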