Fixing casts to int by to_torch_as(...) calls in policies when using discrete actions (#521)
This commit is contained in:
parent
c25926dd8f
commit
cd7654bfd5
@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Type, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.utils import RunningMeanStd
|
from tianshou.utils import RunningMeanStd
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ class PGPolicy(BasePolicy):
|
|||||||
result = self(minibatch)
|
result = self(minibatch)
|
||||||
dist = result.dist
|
dist = result.dist
|
||||||
act = to_torch_as(minibatch.act, result.act)
|
act = to_torch_as(minibatch.act, result.act)
|
||||||
ret = to_torch_as(minibatch.returns, result.act)
|
ret = to_torch(minibatch.returns, torch.float, result.act.device)
|
||||||
log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
|
log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
|
||||||
loss = -(log_prob * ret).mean()
|
loss = -(log_prob * ret).mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user