fix #77
This commit is contained in:
parent
f1951780ab
commit
397e92b0fc
3
CONTRIBUTING.md
Normal file
3
CONTRIBUTING.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Contributing to Tianshou
|
||||
|
||||
Please refer to [tianshou.readthedocs.io/en/latest/contributing.html](https://tianshou.readthedocs.io/en/latest/contributing.html).
|
@@ -1,5 +0,0 @@
|
||||
========================
|
||||
Contributing to Tianshou
|
||||
========================
|
||||
|
||||
Please refer to https://tianshou.readthedocs.io/en/latest/contributing.html
|
@@ -3,7 +3,7 @@ import numpy as np
|
||||
from typing import Dict, List, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
|
||||
|
||||
class PGPolicy(BasePolicy):
|
||||
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
|
||||
for b in batch.split(batch_size):
|
||||
self.optim.zero_grad()
|
||||
dist = self(b).dist
|
||||
a = to_torch(b.act, device=dist.logits.device)
|
||||
r = to_torch(b.returns, device=dist.logits.device)
|
||||
a = to_torch_as(b.act, dist.logits)
|
||||
r = to_torch_as(b.returns, dist.logits)
|
||||
loss = -(dist.log_prob(a) * r).sum()
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
|
Loading…
x
Reference in New Issue
Block a user