add preprocess_fn (#42)

This commit is contained in:
Trinkle23897 2020-05-05 13:39:51 +08:00
parent 04b091d975
commit 075825325e
5 changed files with 100 additions and 50 deletions

View File

@@ -29,13 +29,28 @@ def equal(a, b):
return abs(np.array(a) - np.array(b)).sum() < 1e-6
def preprocess_fn(**kwargs):
# modify info before adding it into the buffer
if kwargs.get('info', None) is not None:
n = len(kwargs['obs'])
info = kwargs['info']
for i in range(n):
info[i].update(rew=kwargs['rew'][i])
return {'info': info}
# or
# return Batch(info=info)
else:
return {}
class Logger(object):
def __init__(self, writer):
self.cnt = 0
self.writer = writer
def log(self, info):
self.writer.add_scalar('key', info['key'], global_step=self.cnt)
self.writer.add_scalar(
'key', np.mean(info['key']), global_step=self.cnt)
self.cnt += 1
@@ -52,21 +67,24 @@ def test_collector():
venv = SubprocVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
c0.collect(n_step=3, log_fn=logger.log)
assert equal(c0.buffer.obs[:3], [0, 1, 0])
assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
c0.collect(n_episode=3, log_fn=logger.log)
assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
c1.collect(n_step=6)
assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
c1.collect(n_episode=2)
assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
preprocess_fn)
c2.collect(n_episode=[1, 2, 2, 2])
assert equal(c2.buffer.obs_next[:26], [
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
@@ -81,7 +99,7 @@ def test_collector():
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100))
c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
c0.collect(n_step=3)
c0.collect(n_episode=3)
env_fns = [
@@ -91,7 +109,7 @@ def test_collector_with_dict_state():
lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
]
envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100))
c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)
@@ -101,7 +119,8 @@ def test_collector_with_dict_state():
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10)
print(batch['obs_next']['index'])
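
The ``preprocess_fn`` used throughout the tests above only rewrites ``info``; as the comment inside it hints, the hook may just as well modify other keys and return a :class:`~tianshou.data.Batch` instead of a dict. A minimal sketch of such a variant (the function name and the scale factor are illustrative, not part of this commit):

import numpy as np
from tianshou.data import Batch

def scale_reward_fn(**kwargs):
    # the collector calls this with only `obs` right after env.reset();
    # there is no reward to rescale in that case
    if kwargs.get('rew', None) is None:
        return {}
    # rescale the reward before it is written into the replay buffer
    return Batch(rew=np.asarray(kwargs['rew']) / 10.0)

# passed to the collector exactly like `preprocess_fn` in the tests:
# c = Collector(policy, env, ReplayBuffer(size=100), scale_reward_fn)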

View File

@@ -130,6 +130,12 @@ class Batch(object):
return sorted([i for i in self.__dict__ if i[0] != '_'] +
list(self._meta))
def get(self, k, d=None):
"""Return self[k] if k in self else d. d defaults to None."""
if k in self.__dict__ or k in self._meta:
return self.__getattr__(k)
return d
def to_numpy(self):
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
operation.
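
The new ``get`` mirrors ``dict.get`` and is what lets the collector read a returned dict and a returned :class:`~tianshou.data.Batch` through the same interface. A quick illustration (the values are made up):

from tianshou.data import Batch

b = Batch(obs=[0, 1], rew=[1., 1.])
b.get('obs')        # returns the stored obs
b.get('act')        # key is absent, returns None
b.get('act', 0)     # key is absent, returns the given default 0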

View File

@@ -14,15 +14,24 @@ class Collector(object):
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param env: an environment or an instance of the
:param env: a ``gym.Env`` environment or an instance of the
:class:`~tianshou.env.BaseVectorEnv` class.
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
``None``, it will automatically assign a small-size
:class:`~tianshou.data.ReplayBuffer`.
:param function preprocess_fn: a function called before the data is added
to the buffer, see issue #42, defaults to ``None``.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
``preprocess_fn`` is a function called before the data is added to the
buffer. It receives the data in batch format, with up to 7 keys as listed in
:class:`~tianshou.data.Batch`, and receives only ``obs`` when the collector
resets the environment. It should return either a dict or a
:class:`~tianshou.data.Batch` containing the modified keys and values.
Examples are in "test/base/test_collector.py".
Example:
::
@@ -68,15 +77,21 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""
def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
def __init__(self, policy, env, buffer=None, preprocess_fn=None,
stat_size=100, **kwargs):
super().__init__()
self.env = env
self.env_num = 1
self.collect_time = 0
self.collect_step = 0
self.collect_episode = 0
self.collect_time = 0
self.buffer = buffer
self.policy = policy
self.preprocess_fn = preprocess_fn
# if preprocess_fn is None:
# def _prep(**kwargs):
# return kwargs
# self.preprocess_fn = _prep
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
self._multi_buf = False # True if buf is a list
@@ -119,7 +134,7 @@ class Collector(object):
self.buffer.reset()
def get_env_num(self):
"""Return the number of environments the collector has."""
"""Return the number of environments the collector have."""
return self.env_num
def reset_env(self):
@@ -127,6 +142,10 @@ class Collector(object):
buffers (if needed).
"""
self._obs = self.env.reset()
if not self._multi_env:
self._obs = self._make_batch(self._obs)
if self.preprocess_fn:
self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
self._act = self._rew = self._done = self._info = None
if self._multi_env:
self.reward = np.zeros(self.env_num)
@@ -231,40 +250,43 @@ class Collector(object):
'There are already many steps in an episode. '
'You should add a time limitation to your environment!',
Warning)
if self._multi_env:
batch_data = Batch(
obs=self._obs, act=self._act, rew=self._rew,
done=self._done, obs_next=None, info=self._info)
else:
batch_data = Batch(
obs=self._make_batch(self._obs),
act=self._make_batch(self._act),
rew=self._make_batch(self._rew),
done=self._make_batch(self._done),
obs_next=None,
info=self._make_batch(self._info))
batch = Batch(
obs=self._obs, act=self._act, rew=self._rew,
done=self._done, obs_next=None, info=self._info,
policy=None)
with torch.no_grad():
result = self.policy(batch_data, self.state)
self.state = result.state if hasattr(result, 'state') else None
result = self.policy(batch, self.state)
self.state = result.get('state', None)
self._policy = self._to_numpy(result.policy) \
if hasattr(result, 'policy') \
else [{}] * self.env_num if self._multi_env else {}
if isinstance(result.act, torch.Tensor):
self._act = self._to_numpy(result.act)
elif not isinstance(self._act, np.ndarray):
self._act = np.array(result.act)
else:
self._act = result.act
if hasattr(result, 'policy') else [{}] * self.env_num
self._act = self._to_numpy(result.act)
obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0])
if log_fn is not None:
log_fn(self._info)
if render is not None:
if not self._multi_env:
obs_next = self._make_batch(obs_next)
self._rew = self._make_batch(self._rew)
self._done = self._make_batch(self._done)
self._info = self._make_batch(self._info)
if log_fn:
log_fn(self._info if self._multi_env else self._info[0])
if render:
self.env.render()
if render > 0:
time.sleep(render)
self.length += 1
self.reward += self._rew
if self.preprocess_fn:
result = self.preprocess_fn(
obs=self._obs, act=self._act, rew=self._rew,
done=self._done, obs_next=obs_next, info=self._info,
policy=self._policy)
self._obs = result.get('obs', self._obs)
self._act = result.get('act', self._act)
self._rew = result.get('rew', self._rew)
self._done = result.get('done', self._done)
obs_next = result.get('obs_next', obs_next)
self._info = result.get('info', self._info)
self._policy = result.get('policy', self._policy)
if self._multi_env:
for i in range(self.env_num):
data = {
@@ -300,6 +322,9 @@ class Collector(object):
self._reset_state(i)
if sum(self._done):
obs_next = self.env.reset(np.where(self._done)[0])
if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get(
'obs', obs_next)
if n_episode != 0:
if isinstance(n_episode, list) and \
(cur_episode >= np.array(n_episode)).all() or \
@@ -309,16 +334,20 @@ class Collector(object):
else:
if self.buffer is not None:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info, self._policy)
self._obs[0], self._act[0], self._rew[0],
self._done[0], obs_next[0], self._info[0],
self._policy[0])
cur_step += 1
if self._done:
cur_episode += 1
reward_sum += self.reward
reward_sum += self.reward[0]
length_sum += self.length
self.reward, self.length = 0, 0
self.state = None
obs_next = self.env.reset()
obs_next = self._make_batch(self.env.reset())
if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get(
'obs', obs_next)
if n_episode != 0 and cur_episode >= n_episode:
break
if n_step != 0 and cur_step >= n_step:
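
Both reset-time call sites above invoke ``self.preprocess_fn(obs=obs_next)`` with nothing but the observation, so a preprocess_fn that transforms observations has to cope with that single-key call as well as with the full transition. A minimal sketch under that assumption (the normalization itself is hypothetical, not part of this commit):

import numpy as np

def normalize_obs_fn(**kwargs):
    # called with only `obs` right after env.reset(), and with the full
    # transition (obs, act, rew, done, obs_next, info, policy) during collect()
    result = {}
    if kwargs.get('obs', None) is not None:
        result['obs'] = np.asarray(kwargs['obs']) / 255.
    if kwargs.get('obs_next', None) is not None:
        result['obs_next'] = np.asarray(kwargs['obs_next']) / 255.
    return result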

View File

@@ -10,7 +10,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
task='', **kwargs):
**kwargs):
"""A wrapper for off-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -89,8 +89,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{result[k]:.2f}'
if writer and global_step % log_interval == 0:
writer.add_scalar(
k + '_' + task if task else k,
result[k], global_step=global_step)
k, result[k], global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
@@ -98,8 +97,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{stat[k].get():.6f}'
if writer and global_step % log_interval == 0:
writer.add_scalar(
k + '_' + task if task else k,
stat[k].get(), global_step=global_step)
k, stat[k].get(), global_step=global_step)
t.update(1)
t.set_postfix(**data)
if t.n <= t.total:
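
Since the ``task`` suffix is dropped from the scalar tags here (and in the on-policy trainer below), different tasks now write to identical tag names; one way to keep them apart, if needed, is to give each run its own writer directory instead of encoding the task in the tag (a suggestion, not part of this commit):

from torch.utils.tensorboard import SummaryWriter

# one writer per task/run; the scalar tags themselves stay plain
writer = SummaryWriter('log/CartPole-v0/dqn')
# result = offpolicy_trainer(policy, train_collector, test_collector, ...,
#                            writer=writer)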

View File

@@ -10,7 +10,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
episode_per_test, batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
task='', **kwargs):
**kwargs):
"""A wrapper for on-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -97,8 +97,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{result[k]:.2f}'
if writer and global_step % log_interval == 0:
writer.add_scalar(
k + '_' + task if task else k,
result[k], global_step=global_step)
k, result[k], global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
@@ -106,8 +105,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{stat[k].get():.6f}'
if writer and global_step % log_interval == 0:
writer.add_scalar(
k + '_' + task if task else k,
stat[k].get(), global_step=global_step)
k, stat[k].get(), global_step=global_step)
t.update(step)
t.set_postfix(**data)
if t.n <= t.total: