This commit is contained in:
Trinkle23897 2020-06-10 12:06:56 +08:00
parent f1951780ab
commit 397e92b0fc
3 changed files with 6 additions and 8 deletions

3
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,3 @@
# Contributing to Tianshou
Please refer to [tianshou.readthedocs.io/en/latest/contributing.html](https://tianshou.readthedocs.io/en/latest/contributing.html).

View File

@@ -1,5 +0,0 @@
========================
Contributing to Tianshou
========================
Please refer to https://tianshou.readthedocs.io/en/latest/contributing.html

View File

@@ -3,7 +3,7 @@ import numpy as np
 from typing import Dict, List, Union, Optional
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_torch_as

 class PGPolicy(BasePolicy):
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
-            a = to_torch(b.act, device=dist.logits.device)
-            r = to_torch(b.returns, device=dist.logits.device)
+            a = to_torch_as(b.act, dist.logits)
+            r = to_torch_as(b.returns, dist.logits)
             loss = -(dist.log_prob(a) * r).sum()
             loss.backward()
             self.optim.step()