compatible with torch==1.5.0 (fix #37)
parent 8812eaa502
commit 70290346ea
@@ -106,27 +106,25 @@ class SACPolicy(DDPGPolicy):
         done = torch.tensor(batch.done,
                             dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
-        obs_result = self(batch)
-        a = obs_result.act
-        current_q1, current_q1a = self.critic1(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        current_q2, current_q2a = self.critic2(
-            np.concatenate([batch.obs, batch.obs]), torch.cat([batch.act, a])
-        ).split(batch.obs.shape[0])
-        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
-            current_q1a, current_q2a)).mean()
         # critic 1
+        current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)
         self.critic1_optim.zero_grad()
-        critic1_loss.backward(retain_graph=True)
+        critic1_loss.backward()
         self.critic1_optim.step()
         # critic 2
+        current_q2 = self.critic2(batch.obs, batch.act)
         critic2_loss = F.mse_loss(current_q2, target_q)
         self.critic2_optim.zero_grad()
-        critic2_loss.backward(retain_graph=True)
+        critic2_loss.backward()
         self.critic2_optim.step()
         # actor
+        obs_result = self(batch)
+        a = obs_result.act
+        current_q1a = self.critic1(batch.obs, a)
+        current_q2a = self.critic2(batch.obs, a)
+        actor_loss = (self._alpha * obs_result.log_prob - torch.min(
+            current_q1a, current_q2a)).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
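Why the reordering matters, briefly: torch 1.5.0 tightened autograd's in-place version checks, and an optimizer's step() updates parameters in place. The old code built one shared critic forward pass that fed both the critic losses and the actor loss, then kept that graph alive with backward(retain_graph=True); once critic1_optim.step() had modified the critic weights in place, backpropagating the actor loss through the retained graph raises a RuntimeError under torch 1.5, even though earlier releases accepted it. Recomputing a fresh forward pass for each loss, as the new code does, lets every backward() free its own graph and removes the need for retain_graph entirely. Below is a minimal standalone sketch of both patterns; the one-layer critic and actor are illustrative stand-ins, not tianshou's classes:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-ins (not tianshou's classes): Q(s, a) and pi(s).
critic = nn.Linear(4 + 2, 1)
actor = nn.Linear(4, 2)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)

obs = torch.randn(8, 4)
act = torch.randn(8, 2)
target_q = torch.zeros(8, 1)

def q(o, a):
    return critic(torch.cat([o, a], dim=1))

# Old pattern: one shared critic forward pass feeds both losses, so the
# graph must be retained across the critic's optimizer step.
a_new = actor(obs)
current_q, current_qa = q(torch.cat([obs, obs]),
                          torch.cat([act, a_new])).split(obs.shape[0])
actor_loss = -current_qa.mean()
critic_loss = F.mse_loss(current_q, target_q)
critic_optim.zero_grad()
critic_loss.backward(retain_graph=True)  # graph kept alive for actor_loss
critic_optim.step()                      # in-place update of critic weights
# actor_loss.backward()  # torch>=1.5: RuntimeError, because the retained
#                        # graph saved critic weights that step() just
#                        # modified in place

# New pattern (this commit): a fresh forward pass per loss, no retain_graph.
critic_loss = F.mse_loss(q(obs, act), target_q)
critic_optim.zero_grad()
critic_loss.backward()                   # graph freed immediately
critic_optim.step()
actor_loss = -q(obs, actor(obs)).mean()  # rebuilt against updated weights
actor_optim.zero_grad()
actor_loss.backward()                    # succeeds
actor_optim.step()

Uncommenting the old pattern's second backward under torch>=1.5 reproduces the error from issue #37 ("one of the variables needed for gradient computation has been modified by an inplace operation"). The cost of the fix is one extra critic forward pass per update, which is cheap next to the backward passes it untangles.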