From 075825325e8658cdef8a8e96f204c85bd748b3d9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 5 May 2020 13:39:51 +0800 Subject: [PATCH] add preprocess_fn (#42) --- test/base/test_collector.py | 33 +++++++++--- tianshou/data/batch.py | 6 +++ tianshou/data/collector.py | 95 +++++++++++++++++++++++------------ tianshou/trainer/offpolicy.py | 8 ++- tianshou/trainer/onpolicy.py | 8 ++- 5 files changed, 100 insertions(+), 50 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 841c178..9e40db2 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -29,13 +29,28 @@ def equal(a, b): return abs(np.array(a) - np.array(b)).sum() < 1e-6 +def preprocess_fn(**kwargs): + # modify info before adding into the buffer + if kwargs.get('info', None) is not None: + n = len(kwargs['obs']) + info = kwargs['info'] + for i in range(n): + info[i].update(rew=kwargs['rew'][i]) + return {'info': info} + # or + # return Batch(info=info) + else: + return {} + + class Logger(object): def __init__(self, writer): self.cnt = 0 self.writer = writer def log(self, info): - self.writer.add_scalar('key', info['key'], global_step=self.cnt) + self.writer.add_scalar( + 'key', np.mean(info['key']), global_step=self.cnt) self.cnt += 1 @@ -52,21 +67,24 @@ def test_collector(): venv = SubprocVectorEnv(env_fns) policy = MyPolicy() env = env_fns[0]() - c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False)) + c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), + preprocess_fn) c0.collect(n_step=3, log_fn=logger.log) assert equal(c0.buffer.obs[:3], [0, 1, 0]) assert equal(c0.buffer[:3].obs_next, [1, 2, 1]) c0.collect(n_episode=3, log_fn=logger.log) assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) - c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False)) + c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), + preprocess_fn) c1.collect(n_step=6) assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) c1.collect(n_episode=2) assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) - c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False)) + c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), + preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) assert equal(c2.buffer.obs_next[:26], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, @@ -81,7 +99,7 @@ def test_collector(): def test_collector_with_dict_state(): env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=3) env_fns = [ @@ -91,7 +109,7 @@ def test_collector_with_dict_state(): lambda: MyTestEnv(size=5, sleep=0, dict_state=True), ] envs = VectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100)) + c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch = c1.sample(10) @@ -101,7 +119,8 @@ def test_collector_with_dict_state(): 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 
1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) - c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4)) + c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), + preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) batch = c2.sample(10) print(batch['obs_next']['index']) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 88169f9..de599b1 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -130,6 +130,12 @@ class Batch(object): return sorted([i for i in self.__dict__ if i[0] != '_'] + list(self._meta)) + def get(self, k, d=None): + """Return self[k] if k in self else d. d defaults to None.""" + if k in self.__dict__ or k in self._meta: + return self.__getattr__(k) + return d + def to_numpy(self): """Change all torch.Tensor to numpy.ndarray. This is an inplace operation. diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9b5478b..bcb920f 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -14,15 +14,24 @@ class Collector(object): :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: an environment or an instance of the + :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. + :param function preprocess_fn: a function called before the data has been + added to the buffer, see issue #42, defaults to ``None``. :param int stat_size: for the moving average of recording speed, defaults to 100. + The ``preprocess_fn`` is a function called before the data has been added + to the buffer with batch format, which receives up to 7 keys as listed in + :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the + collector resets the environment. It returns either a dict or a + :class:`~tianshou.data.Batch` with the modified keys and values. Examples + are in "test/base/test_collector.py". + Example: :: @@ -68,15 +77,21 @@ class Collector(object): Please make sure the given environment has a time limitation. """ - def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs): + def __init__(self, policy, env, buffer=None, preprocess_fn=None, + stat_size=100, **kwargs): super().__init__() self.env = env self.env_num = 1 + self.collect_time = 0 self.collect_step = 0 self.collect_episode = 0 - self.collect_time = 0 self.buffer = buffer self.policy = policy + self.preprocess_fn = preprocess_fn + # if preprocess_fn is None: + # def _prep(**kwargs): + # return kwargs + # self.preprocess_fn = _prep self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) self._multi_buf = False # True if buf is a list @@ -119,7 +134,7 @@ class Collector(object): self.buffer.reset() def get_env_num(self): - """Return the number of environments the collector has.""" + """Return the number of environments the collector have.""" return self.env_num def reset_env(self): @@ -127,6 +142,10 @@ class Collector(object): buffers (if need). 
""" self._obs = self.env.reset() + if not self._multi_env: + self._obs = self._make_batch(self._obs) + if self.preprocess_fn: + self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs) self._act = self._rew = self._done = self._info = None if self._multi_env: self.reward = np.zeros(self.env_num) @@ -231,40 +250,43 @@ class Collector(object): 'There are already many steps in an episode. ' 'You should add a time limitation to your environment!', Warning) - if self._multi_env: - batch_data = Batch( - obs=self._obs, act=self._act, rew=self._rew, - done=self._done, obs_next=None, info=self._info) - else: - batch_data = Batch( - obs=self._make_batch(self._obs), - act=self._make_batch(self._act), - rew=self._make_batch(self._rew), - done=self._make_batch(self._done), - obs_next=None, - info=self._make_batch(self._info)) + batch = Batch( + obs=self._obs, act=self._act, rew=self._rew, + done=self._done, obs_next=None, info=self._info, + policy=None) with torch.no_grad(): - result = self.policy(batch_data, self.state) - self.state = result.state if hasattr(result, 'state') else None + result = self.policy(batch, self.state) + self.state = result.get('state', None) self._policy = self._to_numpy(result.policy) \ - if hasattr(result, 'policy') \ - else [{}] * self.env_num if self._multi_env else {} - if isinstance(result.act, torch.Tensor): - self._act = self._to_numpy(result.act) - elif not isinstance(self._act, np.ndarray): - self._act = np.array(result.act) - else: - self._act = result.act + if hasattr(result, 'policy') else [{}] * self.env_num + self._act = self._to_numpy(result.act) obs_next, self._rew, self._done, self._info = self.env.step( self._act if self._multi_env else self._act[0]) - if log_fn is not None: - log_fn(self._info) - if render is not None: + if not self._multi_env: + obs_next = self._make_batch(obs_next) + self._rew = self._make_batch(self._rew) + self._done = self._make_batch(self._done) + self._info = self._make_batch(self._info) + if log_fn: + log_fn(self._info if self._multi_env else self._info[0]) + if render: self.env.render() if render > 0: time.sleep(render) self.length += 1 self.reward += self._rew + if self.preprocess_fn: + result = self.preprocess_fn( + obs=self._obs, act=self._act, rew=self._rew, + done=self._done, obs_next=obs_next, info=self._info, + policy=self._policy) + self._obs = result.get('obs', self._obs) + self._act = result.get('act', self._act) + self._rew = result.get('rew', self._rew) + self._done = result.get('done', self._done) + obs_next = result.get('obs_next', obs_next) + self._info = result.get('info', self._info) + self._policy = result.get('policy', self._policy) if self._multi_env: for i in range(self.env_num): data = { @@ -300,6 +322,9 @@ class Collector(object): self._reset_state(i) if sum(self._done): obs_next = self.env.reset(np.where(self._done)[0]) + if self.preprocess_fn: + obs_next = self.preprocess_fn(obs=obs_next).get( + 'obs', obs_next) if n_episode != 0: if isinstance(n_episode, list) and \ (cur_episode >= np.array(n_episode)).all() or \ @@ -309,16 +334,20 @@ class Collector(object): else: if self.buffer is not None: self.buffer.add( - self._obs, self._act[0], self._rew, - self._done, obs_next, self._info, self._policy) + self._obs[0], self._act[0], self._rew[0], + self._done[0], obs_next[0], self._info[0], + self._policy[0]) cur_step += 1 if self._done: cur_episode += 1 - reward_sum += self.reward + reward_sum += self.reward[0] length_sum += self.length self.reward, self.length = 0, 0 self.state = None - obs_next = 
self.env.reset() + obs_next = self._make_batch(self.env.reset()) + if self.preprocess_fn: + obs_next = self.preprocess_fn(obs=obs_next).get( + 'obs', obs_next) if n_episode != 0 and cur_episode >= n_episode: break if n_step != 0 and cur_step >= n_step: diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 5648b6c..dfc0159 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -10,7 +10,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, batch_size, train_fn=None, test_fn=None, stop_fn=None, save_fn=None, log_fn=None, writer=None, log_interval=1, verbose=True, - task='', **kwargs): + **kwargs): """A wrapper for off-policy trainer procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -89,8 +89,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k + '_' + task if task else k, - result[k], global_step=global_step) + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -98,8 +97,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, data[k] = f'{stat[k].get():.6f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k + '_' + task if task else k, - stat[k].get(), global_step=global_step) + k, stat[k].get(), global_step=global_step) t.update(1) t.set_postfix(**data) if t.n <= t.total: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 14fe90e..44218ec 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -10,7 +10,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, episode_per_test, batch_size, train_fn=None, test_fn=None, stop_fn=None, save_fn=None, log_fn=None, writer=None, log_interval=1, verbose=True, - task='', **kwargs): + **kwargs): """A wrapper for on-policy trainer procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` @@ -97,8 +97,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k + '_' + task if task else k, - result[k], global_step=global_step) + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -106,8 +105,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, data[k] = f'{stat[k].get():.6f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k + '_' + task if task else k, - stat[k].get(), global_step=global_step) + k, stat[k].get(), global_step=global_step) t.update(step) t.set_postfix(**data) if t.n <= t.total:
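
Below is a minimal usage sketch of the ``preprocess_fn`` hook added by this patch. The reward-scaling logic, the placeholder ``policy``/``env`` objects, and the buffer size are illustrative assumptions, not part of the patch; only the call convention (keyword arguments in, dict or ``Batch`` of modified keys out) follows the collector code above: ::

    import numpy as np

    from tianshou.data import Batch

    def preprocess_fn(**kwargs):
        # Called with only ``obs`` when the collector resets the environment,
        # and with obs/act/rew/done/obs_next/info/policy after each step.
        # Return a dict or Batch containing only the keys you modified;
        # untouched keys fall back to the collector's own values.
        if kwargs.get('rew', None) is not None:
            # hypothetical reward scaling before the data enters the buffer
            return Batch(rew=np.asarray(kwargs['rew']) / 10.0)
        return Batch()

    # Wiring it into a collector (``policy`` and ``env`` are placeholders):
    #     collector = Collector(policy, env, ReplayBuffer(size=1000),
    #                           preprocess_fn)
    #     collector.collect(n_step=100)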