compatible with torch==1.5.0 (fix #37)

This commit is contained in:
Trinkle23897 2020-04-26 11:04:45 +08:00
parent 8812eaa502
commit 70290346ea

View File

@ -106,27 +106,25 @@ class SACPolicy(DDPGPolicy):
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
obs_result = self(batch)
a = obs_result.act
current_q1, current_q1a = self.critic1(
np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
).split(batch.obs.shape[0])
current_q2, current_q2a = self.critic2(
np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
).split(batch.obs.shape[0])
actor_loss = (self._alpha * obs_result.log_prob - torch.min(
current_q1a, current_q2a)).mean()
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward(retain_graph=True)
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act)
critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward(retain_graph=True)
critic2_loss.backward()
self.critic2_optim.step()
# actor
obs_result = self(batch)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a)
current_q2a = self.critic2(batch.obs, a)
actor_loss = (self._alpha * obs_result.log_prob - torch.min(
current_q1a, current_q2a)).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()