add cache buf in collector
This commit is contained in:
parent 543e57cdbd
commit c804662457

setup.py: 2 changed lines
@@ -33,7 +33,7 @@ setup(
         'Programming Language :: Python :: 3.8',
     ],
     keywords='reinforcement learning platform',
-    packages=find_packages(exclude=['tests', 'tests.*',
+    packages=find_packages(exclude=['test', 'test.*',
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[

@@ -14,7 +14,7 @@ def test_replaybuffer(size=10, bufsize=20):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
-    data, indice = buf.sample(4)
+    data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
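Note on the test change: the buffer holds at most bufsize transitions, yet the
test now requests bufsize * 2 samples, so sample() must draw indices with
replacement. A minimal sketch of that behavior, assuming uniform sampling with
numpy (the helper name is illustrative):

    import numpy as np

    def sample_indices(buf_len, batch_size):
        # replace=True (numpy's default) lets batch_size exceed buf_len,
        # which is exactly what sample(bufsize * 2) exercises
        return np.random.choice(buf_len, batch_size)

    indice = sample_indices(20, 40)  # bufsize=20, request twice as many
    assert (indice < 20).all()       # every index stays in range
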
@@ -1,3 +1,4 @@
+import torch
 import numpy as np

@@ -8,6 +9,13 @@ class Batch(object):
         super().__init__()
         self.__dict__.update(kwargs)
 
+    def __getitem__(self, index):
+        b = Batch()
+        for k in self.__dict__.keys():
+            if self.__dict__[k] is not None:
+                b.update(**{k: self.__dict__[k][index]})
+        return b
+
     def update(self, **kwargs):
         self.__dict__.update(kwargs)

@@ -21,9 +29,12 @@ class Batch(object):
             elif isinstance(batch.__dict__[k], np.ndarray):
                 self.__dict__[k] = np.concatenate([
                     self.__dict__[k], batch.__dict__[k]])
+            elif isinstance(batch.__dict__[k], torch.Tensor):
+                self.__dict__[k] = torch.cat([
+                    self.__dict__[k], batch.__dict__[k]])
             elif isinstance(batch.__dict__[k], list):
                 self.__dict__[k] += batch.__dict__[k]
             else:
                 raise TypeError(
-                    'Do not support append with type {} in class Batch.'
+                    'No support for append with type {} in class Batch.'
                     .format(type(batch.__dict__[k])))
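The two Batch additions compose: __getitem__ slices every non-None field (the
keyword expansion **{k: ...} is needed so the real field name is used rather
than the literal key 'k'), and the new torch.Tensor branch lets tensor fields
be concatenated the same way numpy fields already were. A minimal usage
sketch, assuming the concatenation method above is named append, as its error
message suggests:

    import numpy as np
    import torch
    from tianshou.data import Batch

    b = Batch(obs=np.arange(4), act=torch.zeros(4))
    b.append(Batch(obs=np.arange(4), act=torch.ones(4)))
    part = b[2:6]       # __getitem__ indexes every non-None field
    print(part.obs)     # [2 3 0 1]
    print(part.act)     # tensor([0., 0., 1., 1.])
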
@@ -26,6 +26,12 @@ class ReplayBuffer(object):
                 self.__dict__[name] = np.zeros([self._maxsize])
         self.__dict__[name][self._index] = inst
 
+    def update(self, buffer):
+        for i in range(len(buffer)):
+            self.add(
+                buffer.obs[i], buffer.act[i], buffer.rew[i],
+                buffer.done[i], buffer.obs_next[i], buffer.info[i])
+
     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         '''
         weight: importance weights, disabled here
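ReplayBuffer.update replays another buffer's transitions in order through
add(), which is what lets the collector flush a per-env cache buffer into the
shared buffer as one contiguous episode. A sketch, assuming the constructor
takes the buffer size and reset() empties the buffer (both consistent with
how the collector uses it):

    from tianshou.data import ReplayBuffer

    main, cache = ReplayBuffer(100), ReplayBuffer(100)
    for t in range(3):  # stage one short episode in the cache
        cache.add(obs=t, act=0, rew=1.0, done=(t == 2), obs_next=t + 1, info={})
    main.update(cache)  # copy the episode over, in order
    cache.reset()       # empty the cache for the next episode
    assert len(main) == 3
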
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 from copy import deepcopy

@@ -9,21 +10,26 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer):
+    def __init__(self, policy, env, buffer, contiguous=True):
         super().__init__()
         self.env = env
         self.env_num = 1
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
-        self.multi_env = isinstance(env, BaseVectorEnv)
-        if self.multi_env:
+        self._multi_env = isinstance(env, BaseVectorEnv)
+        self._multi_buf = False  # buf is a list
+        # need multiple cache buffers only if contiguous in one buffer
+        self._cached_buf = []
+        if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
                 assert len(self.buffer) == self.env_num,\
                     '# of data buffer does not match the # of input env.'
-            elif isinstance(self.buffer, ReplayBuffer):
-                self.buffer = [deepcopy(buffer) for _ in range(self.env_num)]
+                self._multi_buf = True
+            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+                self._cached_buf = [
+                    deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
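With one shared ReplayBuffer and contiguous=True (the default), the collector
deep-copies the buffer once per environment as a staging area, so transitions
from different envs never interleave inside the main buffer. A hypothetical
construction sketch (VectorEnv and make_env are placeholder names, not part
of this diff):

    envs = VectorEnv([make_env for _ in range(4)])  # 4 parallel envs
    buf = ReplayBuffer(20000)                       # one shared buffer
    collector = Collector(policy, envs, buf)        # contiguous=True
    # internally: _cached_buf holds 4 deep copies of buf, one per env

One consequence of using deepcopy is that each cache buffer inherits the full
maxsize of the shared buffer, although it only ever stages a single episode.
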
@@ -34,7 +40,7 @@ class Collector(object):
         self.stat_length = MovAvg()
 
     def clear_buffer(self):
-        if self.multi_env:
+        if self._multi_buf:
             for b in self.buffer:
                 b.reset()
         else:
@@ -43,17 +49,24 @@ class Collector(object):
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
-        if self.multi_env:
+        if self._multi_env:
             self.reward = np.zeros(self.env_num)
             self.length = np.zeros(self.env_num)
         else:
             self.reward, self.length = 0, 0
+        for b in self._cached_buf:
+            b.reset()
+
+    def _make_batch(self, data):
+        if isinstance(data, np.ndarray):
+            return data[None]
+        else:
+            return [data]
 
     def collect(self, n_step=0, n_episode=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self.multi_env else 0
+        cur_step, cur_episode = 0, 0
         while True:
-            if self.multi_env:
+            if self._multi_env:
                 batch_data = Batch(
@@ -61,41 +74,55 @@ class Collector(object):
                     done=self._done, obs_next=None, info=self._info)
             else:
                 batch_data = Batch(
-                    obs=[self._obs], act=[self._act], rew=[self._rew],
-                    done=[self._done], obs_next=None, info=[self._info])
+                    obs=self._make_batch(self._obs),
+                    act=self._make_batch(self._act),
+                    rew=self._make_batch(self._rew),
+                    done=self._make_batch(self._done),
+                    obs_next=None, info=self._make_batch(self._info))
             result = self.policy.act(batch_data, self.state)
-            self.state = result.state
+            self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act)
-            cur_step += 1
             self.length += 1
             self.reward += self._rew
-            if self.multi_env:
+            if self._multi_env:
                 for i in range(self.env_num):
-                    if n_episode > 0 and \
-                            cur_episode[i] < n_episode or n_episode == 0:
-                        self.buffer[i].add(
-                            self._obs[i], self._act[i], self._rew[i],
-                            self._done[i], obs_next[i], self._info[i])
+                    data = {
+                        'obs': self._obs[i], 'act': self._act[i],
+                        'rew': self._rew[i], 'done': self._done[i],
+                        'obs_next': obs_next[i], 'info': self._info[i]}
+                    if self._cached_buf:
+                        self._cached_buf[i].add(**data)
+                    elif self._multi_buf:
+                        self.buffer[i].add(**data)
+                        cur_step += 1
+                    else:
+                        self.buffer.add(**data)
+                        cur_step += 1
                     if self._done[i]:
-                        cur_episode[i] += 1
+                        cur_episode += 1
                         self.stat_reward.add(self.reward[i])
                         self.stat_length.add(self.length[i])
                         self.reward[i], self.length[i] = 0, 0
+                        if self._cached_buf:
+                            self.buffer.update(self._cached_buf[i])
+                            cur_step += len(self._cached_buf[i])
+                            self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
                         else:
                             self.state[i] = self.state[i] * 0
-                            if hasattr(self.state, 'detach'):
-                                # remove ref in torch
-                                self.state = self.state.detach()
+                            if isinstance(self.state, torch.Tensor):
+                                # detach to drop the autograd reference
+                                self.state = self.state.detach()
-                if n_episode > 0 and (cur_episode >= n_episode).all():
+                if n_episode > 0 and cur_episode >= n_episode:
                     break
             else:
                 self.buffer.add(
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
+                cur_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
@@ -110,7 +137,7 @@ class Collector(object):
         self._obs = obs_next
 
     def sample(self, batch_size):
-        if self.multi_env:
+        if self._multi_buf:
             if batch_size > 0:
                 lens = [len(b) for b in self.buffer]
                 total = sum(lens)
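Putting the collector changes together, the per-step flow with cache buffers
active is roughly the following (a schematic, not runnable on its own):

    # inside collect(), for each parallel env i:
    cached_buf[i].add(**data)            # stage the transition
    if done[i]:
        buffer.update(cached_buf[i])     # flush the finished episode
        cur_step += len(cached_buf[i])   # steps count only when flushed
        cached_buf[i].reset()

Note that cur_step only advances when an episode is flushed, so transitions
from an episode still in progress do not count toward an n_step target until
that episode finishes.
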
@@ -21,9 +21,11 @@ class BasePolicy(ABC):
     def reset(self):
         pass
 
-    @staticmethod
-    def process_fn(batch, buffer, indice):
+    def process_fn(self, batch, buffer, indice):
         return batch
 
+    def sync_weights(self):
+        pass
+
     def exploration(self):
         pass

tianshou/policy/dqn.py: new file, 33 lines

@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+from copy import deepcopy
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+
+class DQNPolicy(BasePolicy, nn.Module):
+    """docstring for DQNPolicy"""
+
+    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+                 use_target_network=True):
+        super().__init__()
+        self.model = model
+        self._gamma = discount_factor
+        self._n_step = estimation_step
+        self._target = use_target_network
+        if use_target_network:
+            self.model_old = deepcopy(self.model)
+
+    def act(self, batch, hidden_state=None):
+        # placeholder: the forward pass is not implemented yet
+        batch_result = Batch()
+        return batch_result
+
+    def sync_weights(self):
+        if self._target:
+            for old, new in zip(
+                    self.model_old.parameters(), self.model.parameters()):
+                old.data.copy_(new.data)
+
+    def process_fn(self, batch, buffer, indice):
+        return batch
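sync_weights() performs a hard target-network update: every online parameter
is copied into the frozen model_old copy. In DQN training this is typically
invoked every fixed number of gradient steps so the bootstrap targets move
slowly. A small sketch of the effect (the Linear layer is an arbitrary
stand-in for a Q-network):

    import torch
    from torch import nn

    net = nn.Linear(4, 2)
    policy = DQNPolicy(net)
    net.weight.data.add_(1.0)  # pretend training moved the online net
    policy.sync_weights()      # hard update: model_old <- model
    assert torch.equal(policy.model_old.weight, net.weight)
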
| @ -1,3 +1,4 @@ | |||||||
|  | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
@@ -8,8 +9,7 @@ class MovAvg(object):
         self.cache = []
 
     def add(self, x):
-        if hasattr(x, 'detach'):
-            # which means x is torch.Tensor (?)
+        if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
         if x != np.inf:
             self.cache.append(x)
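With the explicit isinstance check, torch tensors are detached and moved to
CPU before entering the cache, so the statistics never keep autograd history
alive. A usage sketch, assuming MovAvg exposes a get() accessor for the
current average (not shown in this diff):

    import torch
    from tianshou.utils import MovAvg

    stat = MovAvg()
    stat.add(torch.tensor(2.0, requires_grad=True))  # detached, stored as numpy
    stat.add(3.0)
    print(stat.get())  # 2.5, averaged over the cache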