parent
							
								
									86572c66d4
								
							
						
					
					
						commit
						6da80e045a
					
				
							
								
								
									
										1
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/pytest.yml
									
									
									
									
										vendored
									
									
								
							| @ -11,7 +11,6 @@ jobs: | |||||||
|     strategy: |     strategy: | ||||||
|       matrix: |       matrix: | ||||||
|         python-version: [3.6, 3.7, 3.8] |         python-version: [3.6, 3.7, 3.8] | ||||||
| 
 |  | ||||||
|     steps: |     steps: | ||||||
|     - uses: actions/checkout@v2 |     - uses: actions/checkout@v2 | ||||||
|     - name: Set up Python ${{ matrix.python-version }} |     - name: Set up Python ${{ matrix.python-version }} | ||||||
|  | |||||||
| @ -8,9 +8,9 @@ class MyTestEnv(gym.Env): | |||||||
|         self.sleep = sleep |         self.sleep = sleep | ||||||
|         self.reset() |         self.reset() | ||||||
| 
 | 
 | ||||||
|     def reset(self): |     def reset(self, state=0): | ||||||
|         self.done = False |         self.done = False | ||||||
|         self.index = 0 |         self.index = state | ||||||
|         return self.index |         return self.index | ||||||
| 
 | 
 | ||||||
|     def step(self, action): |     def step(self, action): | ||||||
|  | |||||||
| @ -17,6 +17,7 @@ def test_batch(): | |||||||
|     batch.obs = np.arange(5) |     batch.obs = np.arange(5) | ||||||
|     for i, b in enumerate(batch.split(1, permute=False)): |     for i, b in enumerate(batch.split(1, permute=False)): | ||||||
|         assert b.obs == batch[i].obs |         assert b.obs == batch[i].obs | ||||||
|  |     print(batch) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  | |||||||
| @ -1,3 +1,4 @@ | |||||||
|  | import numpy as np | ||||||
| from tianshou.data import ReplayBuffer | from tianshou.data import ReplayBuffer | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
| @ -28,5 +29,24 @@ def test_replaybuffer(size=10, bufsize=20): | |||||||
|     assert buf2[-1].obs == buf[4].obs |     assert buf2[-1].obs == buf[4].obs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_stack(size=5, bufsize=9, stack_num=4): | ||||||
|  |     env = MyTestEnv(size) | ||||||
|  |     buf = ReplayBuffer(bufsize, stack_num) | ||||||
|  |     obs = env.reset(1) | ||||||
|  |     for i in range(15): | ||||||
|  |         obs_next, rew, done, info = env.step(1) | ||||||
|  |         buf.add(obs, 1, rew, done, None, info) | ||||||
|  |         obs = obs_next | ||||||
|  |         if done: | ||||||
|  |             obs = env.reset(1) | ||||||
|  |     indice = np.arange(len(buf)) | ||||||
|  |     assert abs(buf.get_stack(indice, 'obs') - np.array([ | ||||||
|  |         [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], | ||||||
|  |         [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], | ||||||
|  |         [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6 | ||||||
|  |     print(buf) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     test_replaybuffer() |     test_replaybuffer() | ||||||
|  |     test_stack() | ||||||
|  | |||||||
| @ -26,7 +26,7 @@ def get_args(): | |||||||
|     parser.add_argument('--stack-num', type=int, default=4) |     parser.add_argument('--stack-num', type=int, default=4) | ||||||
|     parser.add_argument('--lr', type=float, default=1e-3) |     parser.add_argument('--lr', type=float, default=1e-3) | ||||||
|     parser.add_argument('--gamma', type=float, default=0.9) |     parser.add_argument('--gamma', type=float, default=0.9) | ||||||
|     parser.add_argument('--n-step', type=int, default=3) |     parser.add_argument('--n-step', type=int, default=4) | ||||||
|     parser.add_argument('--target-update-freq', type=int, default=320) |     parser.add_argument('--target-update-freq', type=int, default=320) | ||||||
|     parser.add_argument('--epoch', type=int, default=100) |     parser.add_argument('--epoch', type=int, default=100) | ||||||
|     parser.add_argument('--step-per-epoch', type=int, default=1000) |     parser.add_argument('--step-per-epoch', type=int, default=1000) | ||||||
|  | |||||||
| @ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1): | |||||||
| 
 | 
 | ||||||
| def test_fn(size=2560): | def test_fn(size=2560): | ||||||
|     policy = PGPolicy(None, None, None, discount_factor=0.1) |     policy = PGPolicy(None, None, None, discount_factor=0.1) | ||||||
|  |     buf = ReplayBuffer(100) | ||||||
|  |     buf.add(1, 1, 1, 1, 1) | ||||||
|     fn = policy.process_fn |     fn = policy.process_fn | ||||||
|     # fn = compute_return_base |     # fn = compute_return_base | ||||||
|     batch = Batch( |     batch = Batch( | ||||||
|         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), |         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), | ||||||
|         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), |         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), | ||||||
|     ) |     ) | ||||||
|     batch = fn(batch, None, None) |     batch = fn(batch, buf, 0) | ||||||
|     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) |     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) | ||||||
|     assert abs(batch.returns - ans).sum() <= 1e-5 |     assert abs(batch.returns - ans).sum() <= 1e-5 | ||||||
|     batch = Batch( |     batch = Batch( | ||||||
|         done=np.array([0, 1, 0, 1, 0, 1, 0.]), |         done=np.array([0, 1, 0, 1, 0, 1, 0.]), | ||||||
|         rew=np.array([7, 6, 1, 2, 3, 4, 5.]), |         rew=np.array([7, 6, 1, 2, 3, 4, 5.]), | ||||||
|     ) |     ) | ||||||
|     batch = fn(batch, None, None) |     batch = fn(batch, buf, 0) | ||||||
|     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) |     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) | ||||||
|     assert abs(batch.returns - ans).sum() <= 1e-5 |     assert abs(batch.returns - ans).sum() <= 1e-5 | ||||||
|     batch = Batch( |     batch = Batch( | ||||||
|         done=np.array([0, 1, 0, 1, 0, 0, 1.]), |         done=np.array([0, 1, 0, 1, 0, 0, 1.]), | ||||||
|         rew=np.array([7, 6, 1, 2, 3, 4, 5.]), |         rew=np.array([7, 6, 1, 2, 3, 4, 5.]), | ||||||
|     ) |     ) | ||||||
|     batch = fn(batch, None, None) |     batch = fn(batch, buf, 0) | ||||||
|     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) |     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) | ||||||
|     assert abs(batch.returns - ans).sum() <= 1e-5 |     assert abs(batch.returns - ans).sum() <= 1e-5 | ||||||
|     if __name__ == '__main__': |     if __name__ == '__main__': | ||||||
| @ -66,7 +68,7 @@ def test_fn(size=2560): | |||||||
|         print(f'vanilla: {(time.time() - t) / cnt}') |         print(f'vanilla: {(time.time() - t) / cnt}') | ||||||
|         t = time.time() |         t = time.time() | ||||||
|         for _ in range(cnt): |         for _ in range(cnt): | ||||||
|             policy.process_fn(batch, None, None) |             policy.process_fn(batch, buf, 0) | ||||||
|         print(f'policy: {(time.time() - t) / cnt}') |         print(f'policy: {(time.time() - t) / cnt}') | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -147,5 +149,5 @@ def test_pg(args=get_args()): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # test_fn() |     test_fn() | ||||||
|     test_pg() |     test_pg() | ||||||
|  | |||||||
| @ -73,6 +73,22 @@ class Batch(object): | |||||||
|                 b.__dict__.update(**{k: self.__dict__[k][index]}) |                 b.__dict__.update(**{k: self.__dict__[k][index]}) | ||||||
|         return b |         return b | ||||||
| 
 | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         """Return str(self).""" | ||||||
|  |         s = self.__class__.__name__ + '(\n' | ||||||
|  |         flag = False | ||||||
|  |         for k in self.__dict__.keys(): | ||||||
|  |             if k[0] != '_' and self.__dict__[k] is not None: | ||||||
|  |                 rpl = '\n' + ' ' * (6 + len(k)) | ||||||
|  |                 obj = str(self.__dict__[k]).replace('\n', rpl) | ||||||
|  |                 s += f'    {k}: {obj},\n' | ||||||
|  |                 flag = True | ||||||
|  |         if flag: | ||||||
|  |             s += ')\n' | ||||||
|  |         else: | ||||||
|  |             s = self.__class__.__name__ + '()\n' | ||||||
|  |         return s | ||||||
|  | 
 | ||||||
|     def append(self, batch): |     def append(self, batch): | ||||||
|         """Append a :class:`~tianshou.data.Batch` object to current batch.""" |         """Append a :class:`~tianshou.data.Batch` object to current batch.""" | ||||||
|         assert isinstance(batch, Batch), 'Only append Batch is allowed!' |         assert isinstance(batch, Batch), 'Only append Batch is allowed!' | ||||||
|  | |||||||
| @ -39,6 +39,34 @@ class ReplayBuffer(object): | |||||||
|         >>> batch_data, indice = buf.sample(batch_size=4) |         >>> batch_data, indice = buf.sample(batch_size=4) | ||||||
|         >>> batch_data.obs == buf[indice].obs |         >>> batch_data.obs == buf[indice].obs | ||||||
|         array([ True,  True,  True,  True]) |         array([ True,  True,  True,  True]) | ||||||
|  | 
 | ||||||
|  |     From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports | ||||||
|  |     frame_stack sampling, typically for RNN usage: | ||||||
|  |     :: | ||||||
|  | 
 | ||||||
|  |         >>> buf = ReplayBuffer(size=9, stack_num=4) | ||||||
|  |         >>> for i in range(16): | ||||||
|  |         ...     done = i % 5 == 0 | ||||||
|  |         ...     buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={}) | ||||||
|  |         >>> print(buf.obs) | ||||||
|  |         [ 9. 10. 11. 12. 13. 14. 15. 7. 8.] | ||||||
|  |         >>> print(buf.done) | ||||||
|  |         [0. 1. 0. 0. 0. 0. 1. 0. 0.] | ||||||
|  |         >>> index = np.arange(len(buf)) | ||||||
|  |         >>> print(buf.get_stack(index, 'obs')) | ||||||
|  |         [[ 7.  7.  8.  9.] | ||||||
|  |          [ 7.  8.  9. 10.] | ||||||
|  |          [11. 11. 11. 11.] | ||||||
|  |          [11. 11. 11. 12.] | ||||||
|  |          [11. 11. 12. 13.] | ||||||
|  |          [11. 12. 13. 14.] | ||||||
|  |          [12. 13. 14. 15.] | ||||||
|  |          [ 7.  7.  7.  7.] | ||||||
|  |          [ 7.  7.  7.  8.]] | ||||||
|  |         >>> # here is another way to get the stacked data | ||||||
|  |         >>> # (stack only for obs and obs_next) | ||||||
|  |         >>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs)) | ||||||
|  |         0.0 | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, size, stack_num=0): |     def __init__(self, size, stack_num=0): | ||||||
| @ -51,8 +79,26 @@ class ReplayBuffer(object): | |||||||
|         """Return len(self).""" |         """Return len(self).""" | ||||||
|         return self._size |         return self._size | ||||||
| 
 | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         """Return str(self).""" | ||||||
|  |         s = self.__class__.__name__ + '(\n' | ||||||
|  |         flag = False | ||||||
|  |         for k in self.__dict__.keys(): | ||||||
|  |             if k[0] != '_' and self.__dict__[k] is not None: | ||||||
|  |                 rpl = '\n' + ' ' * (6 + len(k)) | ||||||
|  |                 obj = str(self.__dict__[k]).replace('\n', rpl) | ||||||
|  |                 s += f'    {k}: {obj},\n' | ||||||
|  |                 flag = True | ||||||
|  |         if flag: | ||||||
|  |             s += ')\n' | ||||||
|  |         else: | ||||||
|  |             s = self.__class__.__name__ + '()\n' | ||||||
|  |         return s | ||||||
|  | 
 | ||||||
|     def _add_to_buffer(self, name, inst): |     def _add_to_buffer(self, name, inst): | ||||||
|         if inst is None: |         if inst is None: | ||||||
|  |             if getattr(self, name, None) is None: | ||||||
|  |                 self.__dict__[name] = None | ||||||
|             return |             return | ||||||
|         if self.__dict__.get(name, None) is None: |         if self.__dict__.get(name, None) is None: | ||||||
|             if isinstance(inst, np.ndarray): |             if isinstance(inst, np.ndarray): | ||||||
| @ -72,13 +118,14 @@ class ReplayBuffer(object): | |||||||
|         i = begin = buffer._index % len(buffer) |         i = begin = buffer._index % len(buffer) | ||||||
|         while True: |         while True: | ||||||
|             self.add( |             self.add( | ||||||
|                 buffer.obs[i], buffer.act[i], buffer.rew[i], |                 buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], | ||||||
|                 buffer.done[i], buffer.obs_next[i], buffer.info[i]) |                 None if buffer.obs_next is None else buffer.obs_next[i], | ||||||
|  |                 buffer.info[i]) | ||||||
|             i = (i + 1) % len(buffer) |             i = (i + 1) % len(buffer) | ||||||
|             if i == begin: |             if i == begin: | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None): |     def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None): | ||||||
|         """Add a batch of data into replay buffer.""" |         """Add a batch of data into replay buffer.""" | ||||||
|         assert isinstance(info, dict), \ |         assert isinstance(info, dict), \ | ||||||
|             'You should return a dict in the last argument of env.step().' |             'You should return a dict in the last argument of env.step().' | ||||||
| @ -97,7 +144,6 @@ class ReplayBuffer(object): | |||||||
|     def reset(self): |     def reset(self): | ||||||
|         """Clear all the data in replay buffer.""" |         """Clear all the data in replay buffer.""" | ||||||
|         self._index = self._size = 0 |         self._index = self._size = 0 | ||||||
|         self.indice = [] |  | ||||||
| 
 | 
 | ||||||
|     def sample(self, batch_size): |     def sample(self, batch_size): | ||||||
|         """Get a random sample from buffer with size equal to batch_size. \ |         """Get a random sample from buffer with size equal to batch_size. \ | ||||||
| @ -114,16 +160,26 @@ class ReplayBuffer(object): | |||||||
|             ]) |             ]) | ||||||
|         return self[indice], indice |         return self[indice], indice | ||||||
| 
 | 
 | ||||||
|     def _get_stack(self, indice, key): |     def get_stack(self, indice, key): | ||||||
|  |         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], | ||||||
|  |         where s is self.key, t is indice. The stack_num (here equals to 4) is | ||||||
|  |         given from buffer initialization procedure. | ||||||
|  |         """ | ||||||
|         if self.__dict__.get(key, None) is None: |         if self.__dict__.get(key, None) is None: | ||||||
|             return None |             return None | ||||||
|         if self._stack == 0: |         if self._stack == 0: | ||||||
|             return self.__dict__[key][indice] |             return self.__dict__[key][indice] | ||||||
|         stack = [] |         stack = [] | ||||||
|  |         # set last frame done to True | ||||||
|  |         last_index = (self._index - 1 + self._size) % self._size | ||||||
|  |         last_done, self.done[last_index] = self.done[last_index], True | ||||||
|         for i in range(self._stack): |         for i in range(self._stack): | ||||||
|             stack = [self.__dict__[key][indice]] + stack |             stack = [self.__dict__[key][indice]] + stack | ||||||
|             indice = indice - 1 + self.done[indice - 1].astype(np.int) |             pre_indice = indice - 1 | ||||||
|             indice[indice == -1] = self._size - 1 |             pre_indice[pre_indice == -1] = self._size - 1 | ||||||
|  |             indice = pre_indice + self.done[pre_indice].astype(np.int) | ||||||
|  |             indice[indice == self._size] = 0 | ||||||
|  |         self.done[last_index] = last_done | ||||||
|         return np.stack(stack, axis=1) |         return np.stack(stack, axis=1) | ||||||
| 
 | 
 | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
| @ -131,11 +187,11 @@ class ReplayBuffer(object): | |||||||
|         return the stacked obs and obs_next with shape [batch, len, ...]. |         return the stacked obs and obs_next with shape [batch, len, ...]. | ||||||
|         """ |         """ | ||||||
|         return Batch( |         return Batch( | ||||||
|             obs=self._get_stack(index, 'obs'), |             obs=self.get_stack(index, 'obs'), | ||||||
|             act=self.act[index], |             act=self.act[index], | ||||||
|             rew=self.rew[index], |             rew=self.rew[index], | ||||||
|             done=self.done[index], |             done=self.done[index], | ||||||
|             obs_next=self._get_stack(index, 'obs_next'), |             obs_next=self.get_stack(index, 'obs_next'), | ||||||
|             info=self.info[index] |             info=self.info[index] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -2,10 +2,10 @@ import time | |||||||
| import torch | import torch | ||||||
| import warnings | import warnings | ||||||
| import numpy as np | import numpy as np | ||||||
| from tianshou.env import BaseVectorEnv | 
 | ||||||
| from tianshou.data import Batch, ReplayBuffer, \ |  | ||||||
|     ListReplayBuffer |  | ||||||
| from tianshou.utils import MovAvg | from tianshou.utils import MovAvg | ||||||
|  | from tianshou.env import BaseVectorEnv | ||||||
|  | from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Collector(object): | class Collector(object): | ||||||
| @ -22,8 +22,8 @@ class Collector(object): | |||||||
|         :class:`~tianshou.data.ReplayBuffer`. |         :class:`~tianshou.data.ReplayBuffer`. | ||||||
|     :param int stat_size: for the moving average of recording speed, defaults |     :param int stat_size: for the moving average of recording speed, defaults | ||||||
|         to 100. |         to 100. | ||||||
|     :param bool store_obs_next: whether to store the obs_next to replay |     :param bool store_obs_next: store the next observation to replay buffer or | ||||||
|         buffer, defaults to ``True``. |         not, defaults to ``True``. | ||||||
| 
 | 
 | ||||||
|     Example: |     Example: | ||||||
|     :: |     :: | ||||||
| @ -302,7 +302,7 @@ class Collector(object): | |||||||
|         self._obs = obs_next |         self._obs = obs_next | ||||||
|         if self._multi_env: |         if self._multi_env: | ||||||
|             cur_episode = sum(cur_episode) |             cur_episode = sum(cur_episode) | ||||||
|         duration = time.time() - start_time |         duration = max(time.time() - start_time, 1e-9) | ||||||
|         self.step_speed.add(cur_step / duration) |         self.step_speed.add(cur_step / duration) | ||||||
|         self.episode_speed.add(cur_episode / duration) |         self.episode_speed.add(cur_episode / duration) | ||||||
|         self.collect_step += cur_step |         self.collect_step += cur_step | ||||||
|  | |||||||
| @ -35,6 +35,8 @@ class PGPolicy(BasePolicy): | |||||||
|         discount factor, :math:`\gamma \in [0, 1]`. |         discount factor, :math:`\gamma \in [0, 1]`. | ||||||
|         """ |         """ | ||||||
|         batch.returns = self._vanilla_returns(batch) |         batch.returns = self._vanilla_returns(batch) | ||||||
|  |         if getattr(batch, 'obs_next', None) is None: | ||||||
|  |             batch.obs_next = buffer[(indice + 1) % len(buffer)].obs | ||||||
|         # batch.returns = self._vectorized_returns(batch) |         # batch.returns = self._vectorized_returns(batch) | ||||||
|         return batch |         return batch | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user