fix pdqn
This commit is contained in:
    parent b23749463e
    commit 6b96f124ae

.github/ISSUE_TEMPLATE.md (vendored) | 8
@@ -4,13 +4,13 @@
     + [ ] documentation request (i.e. "X is missing from the documentation.")
     + [ ] new feature request
 - [ ] I have visited the [source website], and in particular read the [known issues]
-- [ ] I have searched through the [issue tracker] for duplicates
+- [ ] I have searched through the [issue categories] for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, sys
-  print(tianshou.__version__, sys.version, sys.platform)
+  import tianshou, torch, sys
+  print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
   ```
 
   [source website]: https://github.com/thu-ml/tianshou/
   [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
-  [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
+  [issue categories]: https://github.com/thu-ml/tianshou/projects/2

.github/PULL_REQUEST_TEMPLATE.md (vendored) | 8
@@ -8,13 +8,13 @@
 Less important but also useful:
 
 - [ ] I have visited the [source website], and in particular read the [known issues]
-- [ ] I have searched through the [issue tracker] for duplicates
+- [ ] I have searched through the [issue categories] for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, sys
-  print(tianshou.__version__, sys.version, sys.platform)
+  import tianshou, torch, sys
+  print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
   ```
 
   [source website]: https://github.com/thu-ml/tianshou
   [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
-  [issue tracker]: https://github.com/thu-ml/tianshou/projects/2
+  [issue categories]: https://github.com/thu-ml/tianshou/projects/2
@ -20,7 +20,7 @@
 | 
			
		||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 | 
			
		||||
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 | 
			
		||||
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
 | 
			
		||||
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
 | 
			
		||||
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 | 
			
		||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 | 
			
		||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 | 
			
		||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 | 
			
		||||
 | 
			
		||||
@@ -11,7 +11,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
-* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
+* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
@@ -1,5 +1,5 @@
 import numpy as np
-from tianshou.data import ReplayBuffer
+from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -47,6 +47,32 @@ def test_stack(size=5, bufsize=9, stack_num=4):
     print(buf)
 
 
+def test_priortized_replaybuffer(size=32, bufsize=15):
+    env = MyTestEnv(size)
+    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
+    obs = env.reset()
+    action_list = [1] * 5 + [0] * 10 + [1] * 10
+    for i, a in enumerate(action_list):
+        obs_next, rew, done, info = env.step(a)
+        buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5)
+        obs = obs_next
+        assert np.isclose(np.sum((buf.weight / buf._weight_sum)[:buf._size]),
+                          1, rtol=1e-12)
+        data, indice = buf.sample(len(buf) // 2)
+        if len(buf) // 2 == 0:
+            assert len(data) == len(buf)
+        else:
+            assert len(data) == len(buf) // 2
+        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+        assert np.isclose(buf._weight_sum, (buf.weight).sum())
+    data, indice = buf.sample(len(buf) // 2)
+    buf.update_weight(indice, -data.weight / 2)
+    assert np.isclose(buf.weight[indice], np.power(
+        np.abs(-data.weight / 2), buf._alpha)).all()
+    assert np.isclose(buf._weight_sum, (buf.weight).sum())
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_stack()
+    test_priortized_replaybuffer(233333, 200000)
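The assertions in the new test encode the proportional-prioritization rule that `PrioritizedReplayBuffer(bufsize, 0.5, 0.5)` is expected to follow: a priority passed to `add` or `update_weight` is stored as its magnitude raised to `alpha`, and per-transition sampling probabilities are those weights divided by their running sum. A minimal standalone sketch of that relation (plain numpy, illustrative values only, not the tianshou implementation):

```python
import numpy as np

# Proportional prioritization: weight_i = |priority_i| ** alpha,
# sampling probability_i = weight_i / sum(weights).
alpha = 0.5                                    # same exponent the test passes to the buffer
priorities = np.array([0.1, -0.4, 2.0, 0.7])   # e.g. TD errors; only the magnitude matters
weights = np.abs(priorities) ** alpha
probs = weights / weights.sum()

assert np.isclose(probs.sum(), 1.0)            # mirrors the buf.weight / buf._weight_sum check above
print(probs)
```

This is the P(i) = p_i^alpha / sum_k p_k^alpha scheme from the prioritized experience replay paper linked in the README hunk above.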
@@ -1,37 +0,0 @@
-import numpy as np
-from tianshou.data import PrioritizedReplayBuffer
-
-if __name__ == '__main__':
-    from env import MyTestEnv
-else:  # pytest
-    from test.base.env import MyTestEnv
-
-
-def test_replaybuffer(size=32, bufsize=15):
-    env = MyTestEnv(size)
-    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
-    obs = env.reset()
-    action_list = [1] * 5 + [0] * 10 + [1] * 10
-    for i, a in enumerate(action_list):
-        obs_next, rew, done, info = env.step(a)
-        buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5)
-        obs = obs_next
-        assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1,
-                          rtol=1e-12)
-        data, indice = buf.sample(len(buf) // 2)
-        if len(buf)//2 == 0:
-            assert len(data) == len(buf)
-        else:
-            assert len(data) == len(buf)//2
-        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-        assert np.isclose(buf._weight_sum, (buf.weight).sum())
-    data, indice = buf.sample(len(buf) // 2)
-    buf.update_weight(indice, -data.weight/2)
-    assert np.isclose(buf.weight[indice], np.power(
-        np.abs(-data.weight/2), buf._alpha)).all()
-    assert np.isclose(buf._weight_sum, (buf.weight).sum())
-
-
-if __name__ == "__main__":
-    test_replaybuffer(233333, 200000)
-    print("pass")
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, \
     exploration
 
-__version__ = '0.2.1'
+__version__ = '0.2.2'
 __all__ = [
     'env',
     'data',
@@ -104,12 +104,15 @@ class DQNPolicy(BasePolicy):
             r = batch.returns
             if isinstance(r, np.ndarray):
                 r = torch.tensor(r, device=q.device, dtype=q.dtype)
-            td = r-q
-            buffer.update_weight(indice, td.detach().numpy())
+            td = r - q
+            buffer.update_weight(indice, td.detach().cpu().numpy())
             impt_weight = torch.tensor(batch.impt_weight,
                                        device=q.device, dtype=torch.float)
-            loss = (td.pow(2)*impt_weight).mean()
+            loss = (td.pow(2) * impt_weight).mean()
             if not hasattr(batch, 'loss'):
                 batch.loss = loss
             else:
                 batch.loss += loss
         return batch
 
     def forward(self, batch, state=None,
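The substantive fix in this hunk is the added `.cpu()`: calling `.numpy()` on a tensor that still requires grad or that lives on the GPU raises an error, so the TD error has to be detached and moved to the CPU before it is handed back to the buffer as a numpy array. A small self-contained sketch of the pattern (illustrative only, not the tianshou code path):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
td = torch.randn(8, device=device, requires_grad=True)  # stand-in for the TD error r - q

# td.numpy() would fail here: numpy conversion needs a detached, CPU-resident tensor.
td_np = td.detach().cpu().numpy()  # works whether td is on CPU or GPU
print(td_np.dtype, td_np.shape)
```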