fix pdqn
parent b23749463e
commit 6b96f124ae
.github/ISSUE_TEMPLATE.md | 8
@@ -4,13 +4,13 @@
     + [ ] documentation request (i.e. "X is missing from the documentation.")
     + [ ] new feature request
 - [ ] I have visited the [source website], and in particular read the [known issues]
 - [ ] I have searched through the [issue tracker] for duplicates
 - [ ] I have searched through the [issue categories] for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, sys
-  print(tianshou.__version__, sys.version, sys.platform)
+  import tianshou, torch, sys
+  print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
   ```

 [source website]: https://github.com/thu-ml/tianshou/
 [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
 [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
 [issue categories]: https://github.com/thu-ml/tianshou/projects/2
.github/PULL_REQUEST_TEMPLATE.md | 8
@@ -8,13 +8,13 @@
 Less important but also useful:

 - [ ] I have visited the [source website], and in particular read the [known issues]
 - [ ] I have searched through the [issue tracker] for duplicates
 - [ ] I have searched through the [issue categories] for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, sys
-  print(tianshou.__version__, sys.version, sys.platform)
+  import tianshou, torch, sys
+  print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
   ```

 [source website]: https://github.com/thu-ml/tianshou
 [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
 [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
 [issue categories]: https://github.com/thu-ml/tianshou/projects/2
README.md

@@ -20,7 +20,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
-- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
+- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
docs/index.rst

@@ -11,7 +11,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
-* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
+* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
test/base/test_buffer.py

@@ -1,5 +1,5 @@
 import numpy as np
-from tianshou.data import ReplayBuffer
+from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer

 if __name__ == '__main__':
     from env import MyTestEnv
@@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     print(buf)


+def test_priortized_replaybuffer(size=32, bufsize=15):
+    env = MyTestEnv(size)
+    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
+    obs = env.reset()
+    action_list = [1] * 5 + [0] * 10 + [1] * 10
+    for i, a in enumerate(action_list):
+        obs_next, rew, done, info = env.step(a)
+        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
+        obs = obs_next
+        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
+                          1, rtol=1e-12)
+        data, indice = buf.sample(len(buf) // 2)
+        if len(buf) // 2 == 0:
+            assert len(data) == len(buf)
+        else:
+            assert len(data) == len(buf) // 2
+        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+    assert np.isclose(buf._weight_sum, (buf.weight).sum())
+    data, indice = buf.sample(len(buf) // 2)
+    buf.update_weight(indice, -data.weight / 2)
+    assert np.isclose(buf.weight[indice], np.power(
+        np.abs(-data.weight / 2), buf._alpha)).all()
+    assert np.isclose(buf._weight_sum, (buf.weight).sum())
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_stack()
+    test_priortized_replaybuffer(233333, 200000)
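For context on what the assertions above check: the standard proportional-prioritization scheme from the PDQN paper stores priority p_i = |delta_i|^alpha, samples with probability P(i) = p_i / sum_j p_j, and corrects the sampling bias with importance weights w_i = (N * P(i))^(-beta). The sketch below is NumPy-only and illustrative, not Tianshou code; it assumes buf.weight holds |delta|^alpha and buf._weight_sum its sum, which is exactly what the asserts exercise.

```python
import numpy as np

# Illustrative PER arithmetic only -- not the PrioritizedReplayBuffer implementation.
alpha, beta = 0.5, 0.5
td_errors = np.random.randn(15) - 0.5            # initial weights, as in the test above
priority = np.abs(td_errors) ** alpha            # p_i = |delta_i|^alpha
prob = priority / priority.sum()                 # P(i); sums to 1, matching the first assert
impt_weight = (len(priority) * prob) ** (-beta)  # importance-sampling correction per sample
print(prob.sum(), impt_weight.round(3))
```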
(deleted file)

@@ -1,37 +0,0 @@
-import numpy as np
-from tianshou.data import PrioritizedReplayBuffer
-
-if __name__ == '__main__':
-    from env import MyTestEnv
-else:  # pytest
-    from test.base.env import MyTestEnv
-
-
-def test_replaybuffer(size=32, bufsize=15):
-    env = MyTestEnv(size)
-    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
-    obs = env.reset()
-    action_list = [1] * 5 + [0] * 10 + [1] * 10
-    for i, a in enumerate(action_list):
-        obs_next, rew, done, info = env.step(a)
-        buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5)
-        obs = obs_next
-        assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1,
-                          rtol=1e-12)
-        data, indice = buf.sample(len(buf) // 2)
-        if len(buf)//2 == 0:
-            assert len(data) == len(buf)
-        else:
-            assert len(data) == len(buf)//2
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-    assert np.isclose(buf._weight_sum, (buf.weight).sum())
-    data, indice = buf.sample(len(buf) // 2)
-    buf.update_weight(indice, -data.weight/2)
-    assert np.isclose(buf.weight[indice], np.power(
-        np.abs(-data.weight/2), buf._alpha)).all()
-    assert np.isclose(buf._weight_sum, (buf.weight).sum())
-
-
-if __name__ == "__main__":
-    test_replaybuffer(233333, 200000)
-    print("pass")
tianshou/__init__.py

@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, \
     exploration

-__version__ = '0.2.1'
+__version__ = '0.2.2'
 __all__ = [
     'env',
     'data',
@@ -104,12 +104,15 @@ class DQNPolicy(BasePolicy):
         r = batch.returns
         if isinstance(r, np.ndarray):
             r = torch.tensor(r, device=q.device, dtype=q.dtype)
-        td = r-q
-        buffer.update_weight(indice, td.detach().numpy())
+        td = r - q
+        buffer.update_weight(indice, td.detach().cpu().numpy())
         impt_weight = torch.tensor(batch.impt_weight,
                                    device=q.device, dtype=torch.float)
-        loss = (td.pow(2)*impt_weight).mean()
-        batch.loss = loss
+        loss = (td.pow(2) * impt_weight).mean()
+        if not hasattr(batch, 'loss'):
+            batch.loss = loss
+        else:
+            batch.loss += loss
         return batch

     def forward(self, batch, state=None,
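A minimal plain-PyTorch sketch of the weighted TD loss computed in the hunk above (a standalone illustration with made-up tensors, not the DQNPolicy code). It also shows why the fix routes the tensor through .cpu() before .numpy(): a CUDA tensor cannot be converted to a NumPy array directly.

```python
import torch

q = torch.tensor([1.0, 2.0, 3.0])             # Q(s, a) predicted by the network
returns = torch.tensor([1.5, 1.0, 3.5])       # (n-step) returns from the buffer
impt_weight = torch.tensor([0.7, 1.0, 0.4])   # importance-sampling weights per sample

td = returns - q                               # per-transition TD error
loss = (td.pow(2) * impt_weight).mean()        # weighted MSE, as in the diff

# New priorities fed back to the buffer: .detach() drops the autograd graph and
# .cpu() moves a (possibly CUDA) tensor to host memory so .numpy() is legal.
new_priority = td.detach().cpu().numpy()
print(loss.item(), new_priority)
```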