add preprocess_fn (#42)

parent 04b091d975
commit 075825325e
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
@@ -29,13 +29,28 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6
 
 
+def preprocess_fn(**kwargs):
+    # modify info before adding into the buffer
+    if kwargs.get('info', None) is not None:
+        n = len(kwargs['obs'])
+        info = kwargs['info']
+        for i in range(n):
+            info[i].update(rew=kwargs['rew'][i])
+        return {'info': info}
+        # or
+        # return Batch(info=info)
+    else:
+        return {}
+
+
 class Logger(object):
     def __init__(self, writer):
         self.cnt = 0
         self.writer = writer
 
     def log(self, info):
-        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.writer.add_scalar(
+            'key', np.mean(info['key']), global_step=self.cnt)
         self.cnt += 1
 
 
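A minimal sketch (not part of the commit) of what this test's ``preprocess_fn`` does when it is invoked the way the collector invokes it, with toy numpy inputs standing in for collected data::

    import numpy as np

    obs = np.array([0, 1])      # two parallel environments
    rew = np.array([1., 0.])
    info = [{}, {}]
    out = preprocess_fn(obs=obs, rew=rew, info=info)
    # each per-env info dict now carries its own reward
    assert out == {'info': [{'rew': 1.0}, {'rew': 0.0}]}

Only the modified key (``info``) is returned; the collector keeps the old values for everything else.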
@@ -52,21 +67,24 @@ def test_collector():
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
+    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
-    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
+    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c1.collect(n_step=6)
     assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
     assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
     assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
     assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
+    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c2.collect(n_episode=[1, 2, 2, 2])
     assert equal(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
@@ -81,7 +99,7 @@ def test_collector():
 def test_collector_with_dict_state():
     env = MyTestEnv(size=5, sleep=0, dict_state=True)
     policy = MyPolicy(dict_state=True)
-    c0 = Collector(policy, env, ReplayBuffer(size=100))
+    c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
     c0.collect(n_step=3)
     c0.collect(n_episode=3)
     env_fns = [
@@ -91,7 +109,7 @@ def test_collector_with_dict_state():
         lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
     ]
     envs = VectorEnv(env_fns)
-    c1 = Collector(policy, envs, ReplayBuffer(size=100))
+    c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
     c1.collect(n_step=10)
     c1.collect(n_episode=[2, 1, 1, 2])
     batch = c1.sample(10)
@@ -101,7 +119,8 @@ def test_collector_with_dict_state():
         0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
         0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
         1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
-    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
+    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
+                   preprocess_fn)
     c2.collect(n_episode=[0, 0, 0, 10])
     batch = c2.sample(10)
     print(batch['obs_next']['index'])
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
@@ -130,6 +130,12 @@ class Batch(object):
         return sorted([i for i in self.__dict__ if i[0] != '_'] +
                       list(self._meta))
 
+    def get(self, k, d=None):
+        """Return self[k] if k in self else d. d defaults to None."""
+        if k in self.__dict__ or k in self._meta:
+            return self.__getattr__(k)
+        return d
+
     def to_numpy(self):
         """Change all torch.Tensor to numpy.ndarray. This is an inplace
         operation.
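A quick usage sketch of the new ``Batch.get``, which mirrors ``dict.get`` (the ``Batch`` construction here is illustrative of that era's keyword-argument constructor)::

    from tianshou.data import Batch

    b = Batch(obs=[0, 1], act=[1, 0])
    b.get('obs')         # [0, 1]
    b.get('rew')         # None (missing key, default default)
    b.get('rew', 0.0)    # 0.0 (falls back to the provided default)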
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
@@ -14,15 +14,24 @@ class Collector(object):
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
         class.
-    :param env: an environment or an instance of the
+    :param env: a ``gym.Env`` environment or an instance of the
         :class:`~tianshou.env.BaseVectorEnv` class.
     :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
         class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
         ``None``, it will automatically assign a small-size
         :class:`~tianshou.data.ReplayBuffer`.
+    :param function preprocess_fn: a function called before the data is added
+        to the buffer, see issue #42, defaults to ``None``.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
 
+    The ``preprocess_fn`` is a function called before the data is added to
+    the buffer, in batch format. It receives up to 7 keys, as listed in
+    :class:`~tianshou.data.Batch`, and receives only ``obs`` when the
+    collector resets the environment. It returns either a dict or a
+    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
+    are in "test/base/test_collector.py".
+
     Example:
     ::
 
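A sketch of that contract (hypothetical user code, not part of the commit): a ``preprocess_fn`` has to tolerate the reset-time call that carries only ``obs``, and it may return just the keys it modified::

    import numpy as np

    def clip_reward_fn(**kwargs):
        if 'rew' not in kwargs:
            # reset-time call: only 'obs' is passed, nothing to change
            return {}
        # step-time call: clip rewards before they reach the buffer
        return {'rew': np.clip(kwargs['rew'], -1.0, 1.0)}

    # collector = Collector(policy, env, buffer, clip_reward_fn)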
@@ -68,15 +77,21 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """
 
-    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
+    def __init__(self, policy, env, buffer=None, preprocess_fn=None,
+                 stat_size=100, **kwargs):
         super().__init__()
         self.env = env
         self.env_num = 1
-        self.collect_time = 0
         self.collect_step = 0
         self.collect_episode = 0
         self.collect_time = 0
         self.buffer = buffer
         self.policy = policy
+        self.preprocess_fn = preprocess_fn
+        # if preprocess_fn is None:
+        #     def _prep(**kwargs):
+        #         return kwargs
+        #     self.preprocess_fn = _prep
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
         self._multi_buf = False  # True if buf is a list
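The commented-out lines record an alternative the author evidently considered: installing an identity preprocessor when none is given, so later call sites would need no ``None`` check. A sketch of that variant (hypothetical, not what the commit ships)::

    def _identity_prep(**kwargs):
        # pass every key through unchanged
        return kwargs

    # in __init__: self.preprocess_fn = preprocess_fn or _identity_prep

The commit instead keeps ``preprocess_fn`` as ``None`` and guards each use with ``if self.preprocess_fn:``, which skips the call overhead entirely in the common case.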
@@ -119,7 +134,7 @@ class Collector(object):
         self.buffer.reset()
 
     def get_env_num(self):
-        """Return the number of environments the collector has."""
+        """Return the number of environments the collector have."""
         return self.env_num
 
     def reset_env(self):
@@ -127,6 +142,10 @@ class Collector(object):
         buffers (if need).
         """
         self._obs = self.env.reset()
+        if not self._multi_env:
+            self._obs = self._make_batch(self._obs)
+        if self.preprocess_fn:
+            self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
         self._act = self._rew = self._done = self._info = None
         if self._multi_env:
            self.reward = np.zeros(self.env_num)
@@ -231,40 +250,43 @@ class Collector(object):
                     'There are already many steps in an episode. '
                     'You should add a time limitation to your environment!',
                     Warning)
-            if self._multi_env:
-                batch_data = Batch(
-                    obs=self._obs, act=self._act, rew=self._rew,
-                    done=self._done, obs_next=None, info=self._info)
-            else:
-                batch_data = Batch(
-                    obs=self._make_batch(self._obs),
-                    act=self._make_batch(self._act),
-                    rew=self._make_batch(self._rew),
-                    done=self._make_batch(self._done),
-                    obs_next=None,
-                    info=self._make_batch(self._info))
+            batch = Batch(
+                obs=self._obs, act=self._act, rew=self._rew,
+                done=self._done, obs_next=None, info=self._info,
+                policy=None)
             with torch.no_grad():
-                result = self.policy(batch_data, self.state)
-            self.state = result.state if hasattr(result, 'state') else None
+                result = self.policy(batch, self.state)
+            self.state = result.get('state', None)
             self._policy = self._to_numpy(result.policy) \
-                if hasattr(result, 'policy') \
-                else [{}] * self.env_num if self._multi_env else {}
-            if isinstance(result.act, torch.Tensor):
-                self._act = self._to_numpy(result.act)
-            elif not isinstance(self._act, np.ndarray):
-                self._act = np.array(result.act)
-            else:
-                self._act = result.act
+                if hasattr(result, 'policy') else [{}] * self.env_num
+            self._act = self._to_numpy(result.act)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
-            if log_fn is not None:
-                log_fn(self._info)
-            if render is not None:
+            if not self._multi_env:
+                obs_next = self._make_batch(obs_next)
+                self._rew = self._make_batch(self._rew)
+                self._done = self._make_batch(self._done)
+                self._info = self._make_batch(self._info)
+            if log_fn:
+                log_fn(self._info if self._multi_env else self._info[0])
+            if render:
                 self.env.render()
                 if render > 0:
                     time.sleep(render)
             self.length += 1
             self.reward += self._rew
+            if self.preprocess_fn:
+                result = self.preprocess_fn(
+                    obs=self._obs, act=self._act, rew=self._rew,
+                    done=self._done, obs_next=obs_next, info=self._info,
+                    policy=self._policy)
+                self._obs = result.get('obs', self._obs)
+                self._act = result.get('act', self._act)
+                self._rew = result.get('rew', self._rew)
+                self._done = result.get('done', self._done)
+                obs_next = result.get('obs_next', obs_next)
+                self._info = result.get('info', self._info)
+                self._policy = result.get('policy', self._policy)
             if self._multi_env:
                 for i in range(self.env_num):
                     data = {
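The merge pattern above, ``result.get(key, old_value)``, means a ``preprocess_fn`` only has to return the keys it actually changed; every other key keeps its previous value. A hedged sketch of a hook that rewrites observations only (hypothetical, not from the commit)::

    import numpy as np

    def scale_obs_fn(**kwargs):
        out = {}
        # rewrite observations; untouched keys fall back to old values
        if 'obs' in kwargs:
            out['obs'] = np.asarray(kwargs['obs']) / 255.0
        if kwargs.get('obs_next') is not None:
            out['obs_next'] = np.asarray(kwargs['obs_next']) / 255.0
        return out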
@@ -300,6 +322,9 @@ class Collector(object):
                         self._reset_state(i)
             if sum(self._done):
                 obs_next = self.env.reset(np.where(self._done)[0])
+                if self.preprocess_fn:
+                    obs_next = self.preprocess_fn(obs=obs_next).get(
+                        'obs', obs_next)
             if n_episode != 0:
                 if isinstance(n_episode, list) and \
                         (cur_episode >= np.array(n_episode)).all() or \
@@ -309,16 +334,20 @@ class Collector(object):
             else:
                 if self.buffer is not None:
                     self.buffer.add(
-                        self._obs, self._act[0], self._rew,
-                        self._done, obs_next, self._info, self._policy)
+                        self._obs[0], self._act[0], self._rew[0],
+                        self._done[0], obs_next[0], self._info[0],
+                        self._policy[0])
                 cur_step += 1
                 if self._done:
                     cur_episode += 1
-                    reward_sum += self.reward
+                    reward_sum += self.reward[0]
                     length_sum += self.length
                     self.reward, self.length = 0, 0
                     self.state = None
-                    obs_next = self.env.reset()
+                    obs_next = self._make_batch(self.env.reset())
+                    if self.preprocess_fn:
+                        obs_next = self.preprocess_fn(obs=obs_next).get(
+                            'obs', obs_next)
                 if n_episode != 0 and cur_episode >= n_episode:
                     break
             if n_step != 0 and cur_step >= n_step:
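The new ``[0]`` indexing in this single-environment branch works because the collector now wraps scalar transitions into length-one arrays up front, so ``preprocess_fn`` always sees batched data. A stand-in sketch of that wrapping (``_make_batch`` here is illustrative of the helper the code calls)::

    import numpy as np

    def _make_batch(data):
        # give scalar data a leading batch axis of size one
        if isinstance(data, np.ndarray):
            return data[None]
        return np.array([data])

    rew = _make_batch(1.0)   # array([1.]); stored back as rew[0]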
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
@@ -10,7 +10,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       batch_size,
                       train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
                       log_fn=None, writer=None, log_interval=1, verbose=True,
-                      task='', **kwargs):
+                      **kwargs):
     """A wrapper for off-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -89,8 +89,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            result[k], global_step=global_step)
+                            k, result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()
@@ -98,8 +97,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=global_step)
                 t.update(1)
                 t.set_postfix(**data)
             if t.n <= t.total:
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
@@ -10,7 +10,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      episode_per_test, batch_size,
                      train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
                      log_fn=None, writer=None, log_interval=1, verbose=True,
-                     task='', **kwargs):
+                     **kwargs):
     """A wrapper for on-policy trainer procedure.
 
     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
@@ -97,8 +97,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            result[k], global_step=global_step)
+                            k, result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()
@@ -106,8 +105,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=global_step)
                 t.update(step)
                 t.set_postfix(**data)
             if t.n <= t.total:
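Both trainers drop the ``task`` suffix from TensorBoard tags. If per-task separation is still wanted, the usual alternative is to encode the task in the writer's log directory instead (illustrative snippet, not part of the commit)::

    from torch.utils.tensorboard import SummaryWriter

    # one run directory per task keeps the scalar tags themselves clean
    writer = SummaryWriter(log_dir='log/CartPole-v0/dqn')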