compatible with torch==1.5.0 (fix #37)
commit 70290346ea
parent 8812eaa502
@@ -106,27 +106,25 @@ class SACPolicy(DDPGPolicy):
             done = torch.tensor(batch.done,
                                 dtype=torch.float, device=dev)[:, None]
             target_q = (rew + (1. - done) * self._gamma * target_q)
-        obs_result = self(batch)
-        a = obs_result.act
-        current_q1, current_q1a = self.critic1(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        current_q2, current_q2a = self.critic2(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
         # critic 1
+        current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
-        critic1_loss.backward(retain_graph=True)
+        critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
+        current_q2 = self.critic2(batch.obs, batch.act)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
-        critic2_loss.backward(retain_graph=True)
+        critic2_loss.backward()
         self.critic2_optim.step()
         # actor
+        obs_result = self(batch)
+        a = obs_result.act
+        current_q1a = self.critic1(batch.obs, a)
+        current_q2a = self.critic2(batch.obs, a)
+        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
+            current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()