This commit is contained in:
Trinkle23897 2020-04-26 15:11:20 +08:00
parent b23749463e
commit 6b96f124ae
8 changed files with 45 additions and 53 deletions

View File

@ -4,13 +4,13 @@
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue tracker] for duplicates
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, sys
print(tianshou.__version__, sys.version, sys.platform)
import tianshou, torch, sys
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```
[source website]: https://github.com/thu-ml/tianshou/
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue tracker]: https://github.com/thu-ml/tianshou/projects/2
[issue categories]: https://github.com/thu-ml/tianshou/projects/2

View File

@ -8,13 +8,13 @@
Less important but also useful:
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue tracker] for duplicates
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, sys
print(tianshou.__version__, sys.version, sys.platform)
import tianshou, torch, sys
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```
[source website]: https://github.com/thu-ml/tianshou
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue tracker]: https://github.com/thu-ml/tianshou/projects/2
[issue categories]: https://github.com/thu-ml/tianshou/projects/2

View File

@ -20,7 +20,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)

View File

@ -11,7 +11,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_

View File

@ -1,5 +1,5 @@
import numpy as np
from tianshou.data import ReplayBuffer
from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4):
print(buf)
def test_priortized_replaybuffer(size=32, bufsize=15):
env = MyTestEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
obs = obs_next
assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
1, rtol=1e-12)
data, indice = buf.sample(len(buf) // 2)
if len(buf) // 2 == 0:
assert len(data) == len(buf)
else:
assert len(data) == len(buf) // 2
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
assert np.isclose(buf._weight_sum, (buf.weight).sum())
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight / 2)
assert np.isclose(buf.weight[indice], np.power(
np.abs(-data.weight / 2), buf._alpha)).all()
assert np.isclose(buf._weight_sum, (buf.weight).sum())
if __name__ == '__main__':
test_replaybuffer()
test_stack()
test_priortized_replaybuffer(233333, 200000)

View File

@ -1,37 +0,0 @@
import numpy as np
from tianshou.data import PrioritizedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
from test.base.env import MyTestEnv
def test_replaybuffer(size=32, bufsize=15):
env = MyTestEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5)
obs = obs_next
assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1,
rtol=1e-12)
data, indice = buf.sample(len(buf) // 2)
if len(buf)//2 == 0:
assert len(data) == len(buf)
else:
assert len(data) == len(buf)//2
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
assert np.isclose(buf._weight_sum, (buf.weight).sum())
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight/2)
assert np.isclose(buf.weight[indice], np.power(
np.abs(-data.weight/2), buf._alpha)).all()
assert np.isclose(buf._weight_sum, (buf.weight).sum())
if __name__ == "__main__":
test_replaybuffer(233333, 200000)
print("pass")

View File

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
exploration
__version__ = '0.2.1'
__version__ = '0.2.2'
__all__ = [
'env',
'data',

View File

@ -104,12 +104,15 @@ class DQNPolicy(BasePolicy):
r = batch.returns
if isinstance(r, np.ndarray):
r = torch.tensor(r, device=q.device, dtype=q.dtype)
td = r-q
buffer.update_weight(indice, td.detach().numpy())
td = r - q
buffer.update_weight(indice, td.detach().cpu().numpy())
impt_weight = torch.tensor(batch.impt_weight,
device=q.device, dtype=torch.float)
loss = (td.pow(2)*impt_weight).mean()
batch.loss = loss
loss = (td.pow(2) * impt_weight).mean()
if not hasattr(batch, 'loss'):
batch.loss = loss
else:
batch.loss += loss
return batch
def forward(self, batch, state=None,