fix #77
This commit is contained in:
parent
f1951780ab
commit
397e92b0fc
3
CONTRIBUTING.md
Normal file
3
CONTRIBUTING.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Contributing to Tianshou
|
||||||
|
|
||||||
|
Please refer to [tianshou.readthedocs.io/en/latest/contributing.html](https://tianshou.readthedocs.io/en/latest/contributing.html).
|
@ -1,5 +0,0 @@
|
|||||||
========================
|
|
||||||
Contributing to Tianshou
|
|
||||||
========================
|
|
||||||
|
|
||||||
Please refer to https://tianshou.readthedocs.io/en/latest/contributing.html
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||||||
from typing import Dict, List, Union, Optional
|
from typing import Dict, List, Union, Optional
|
||||||
|
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||||
|
|
||||||
|
|
||||||
class PGPolicy(BasePolicy):
|
class PGPolicy(BasePolicy):
|
||||||
@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
|
|||||||
for b in batch.split(batch_size):
|
for b in batch.split(batch_size):
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
dist = self(b).dist
|
dist = self(b).dist
|
||||||
a = to_torch(b.act, device=dist.logits.device)
|
a = to_torch_as(b.act, dist.logits)
|
||||||
r = to_torch(b.returns, device=dist.logits.device)
|
r = to_torch_as(b.returns, dist.logits)
|
||||||
loss = -(dist.log_prob(a) * r).sum()
|
loss = -(dist.log_prob(a) * r).sum()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optim.step()
|
self.optim.step()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user