This commit is contained in:
Trinkle23897 2020-06-10 12:06:56 +08:00
parent f1951780ab
commit 397e92b0fc
3 changed files with 6 additions and 8 deletions

3
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,3 @@
# Contributing to Tianshou
Please refer to [tianshou.readthedocs.io/en/latest/contributing.html](https://tianshou.readthedocs.io/en/latest/contributing.html).

View File

@@ -1,5 +0,0 @@
========================
Contributing to Tianshou
========================
Please refer to https://tianshou.readthedocs.io/en/latest/contributing.html

View File

@@ -3,7 +3,7 @@ import numpy as np
from typing import Dict, List, Union, Optional
from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
class PGPolicy(BasePolicy):
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
for b in batch.split(batch_size):
self.optim.zero_grad()
dist = self(b).dist
-a = to_torch(b.act, device=dist.logits.device)
-r = to_torch(b.returns, device=dist.logits.device)
+a = to_torch_as(b.act, dist.logits)
+r = to_torch_as(b.returns, dist.logits)
loss = -(dist.log_prob(a) * r).sum()
loss.backward()
self.optim.step()