Fixing casts to int by to_torch_as(...) calls in policies when using discrete actions (#521)

This commit is contained in:
Kenneth Schröder 2022-02-06 20:42:46 +01:00 committed by GitHub
parent c25926dd8f
commit cd7654bfd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Type, Union
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
from tianshou.policy import BasePolicy
from tianshou.utils import RunningMeanStd
@ -131,7 +131,7 @@ class PGPolicy(BasePolicy):
result = self(minibatch)
dist = result.dist
act = to_torch_as(minibatch.act, result.act)
ret = to_torch_as(minibatch.returns, result.act)
ret = to_torch(minibatch.returns, torch.float, result.act.device)
log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
loss = -(log_prob * ret).mean()
loss.backward()