From 70290346eaf7aade89a60ff7cfb0f7d11d02d918 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 26 Apr 2020 11:04:45 +0800
Subject: [PATCH] compatible with torch==1.5.0 (fix #37)

---
 tianshou/policy/modelfree/sac.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 8b27fde..d701aca 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -106,27 +106,25 @@ class SACPolicy(DDPGPolicy):
             done = torch.tensor(batch.done,
                                 dtype=torch.float, device=dev)[:, None]
             target_q = (rew + (1. - done) * self._gamma * target_q)
-        obs_result = self(batch)
-        a = obs_result.act
-        current_q1, current_q1a = self.critic1(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        current_q2, current_q2a = self.critic2(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
         # critic 1
+        current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
-        critic1_loss.backward(retain_graph=True)
+        critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
+        current_q2 = self.critic2(batch.obs, batch.act)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
-        critic2_loss.backward(retain_graph=True)
+        critic2_loss.backward()
         self.critic2_optim.step()
         # actor
+        obs_result = self(batch)
+        a = obs_result.act
+        current_q1a = self.critic1(batch.obs, a)
+        current_q2a = self.critic2(batch.obs, a)
+        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
+            current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
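
Note on the change: the removed code built a single autograd graph feeding both critic losses and the actor loss, so the critic backward passes needed retain_graph=True while critic1_optim.step() / critic2_optim.step() mutated the critic parameters in place between them. Under torch 1.5.0 that pattern typically fails with a RuntimeError about a variable needed for gradient computation being modified by an in-place operation, which appears to be what the subject line's "compatible with torch==1.5.0" refers to. The patch instead gives each loss its own forward pass and drops retain_graph. The sketch below reproduces only that update ordering; the networks, shapes, learning rates, and target_q are made-up stand-ins for illustration and are not tianshou's API (the entropy term in the actor loss is also omitted).

import torch
import torch.nn as nn
import torch.nn.functional as F

obs_dim, act_dim, batch_size = 4, 2, 8

actor = nn.Linear(obs_dim, act_dim)        # stand-in policy network
critic1 = nn.Linear(obs_dim + act_dim, 1)  # stand-in Q networks
critic2 = nn.Linear(obs_dim + act_dim, 1)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=1e-3)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=1e-3)

obs = torch.randn(batch_size, obs_dim)
act = torch.randn(batch_size, act_dim)
target_q = torch.randn(batch_size, 1)      # pretend Bellman target (no grad)

# critic 1: its own forward pass, plain backward, then step
current_q1 = critic1(torch.cat([obs, act], dim=1))
critic1_loss = F.mse_loss(current_q1, target_q)
critic1_optim.zero_grad()
critic1_loss.backward()
critic1_optim.step()

# critic 2: same pattern, independent graph
current_q2 = critic2(torch.cat([obs, act], dim=1))
critic2_loss = F.mse_loss(current_q2, target_q)
critic2_optim.zero_grad()
critic2_loss.backward()
critic2_optim.step()

# actor: re-run the critics on the current policy action only after both
# critics have been stepped, so no graph is shared across optimizer steps
a = torch.tanh(actor(obs))
q1a = critic1(torch.cat([obs, a], dim=1))
q2a = critic2(torch.cat([obs, a], dim=1))
actor_loss = (-torch.min(q1a, q2a)).mean()
actor_optim.zero_grad()
actor_loss.backward()
actor_optim.step()

The cost of this structure is one extra critic forward pass per update compared with the batched-and-split version it replaces, in exchange for three fully independent graphs that never need retain_graph.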