add preprocess_fn (#42)

parent 04b091d975
commit 075825325e
@@ -29,13 +29,28 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6


+def preprocess_fn(**kwargs):
+    # modify info before adding into the buffer
+    if kwargs.get('info', None) is not None:
+        n = len(kwargs['obs'])
+        info = kwargs['info']
+        for i in range(n):
+            info[i].update(rew=kwargs['rew'][i])
+        return {'info': info}
+        # or
+        # return Batch(info=info)
+    else:
+        return {}
+
+
 class Logger(object):
     def __init__(self, writer):
         self.cnt = 0
         self.writer = writer

     def log(self, info):
-        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.writer.add_scalar(
+            'key', np.mean(info['key']), global_step=self.cnt)
         self.cnt += 1

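A quick standalone check of the hook defined above (the toy values here are illustrative and are not taken from the test):

    out = preprocess_fn(obs=[0, 1], rew=[1.0, 0.0],
                        info=[{'key': 1}, {'key': 0}])
    # each per-environment info dict now also carries its reward:
    # {'info': [{'key': 1, 'rew': 1.0}, {'key': 0, 'rew': 0.0}]}
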
@@ -52,21 +67,24 @@ def test_collector():
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
+    c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
     c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
-    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
+    c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c1.collect(n_step=6)
     assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
     assert equal(c1.buffer[:11].obs_next, [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
     c1.collect(n_episode=2)
     assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
     assert equal(c1.buffer[11:21].obs_next, [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
+    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+                   preprocess_fn)
     c2.collect(n_episode=[1, 2, 2, 2])
     assert equal(c2.buffer.obs_next[:26], [
         1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,

@@ -81,7 +99,7 @@ def test_collector():
 def test_collector_with_dict_state():
     env = MyTestEnv(size=5, sleep=0, dict_state=True)
     policy = MyPolicy(dict_state=True)
-    c0 = Collector(policy, env, ReplayBuffer(size=100))
+    c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
     c0.collect(n_step=3)
     c0.collect(n_episode=3)
     env_fns = [

@@ -91,7 +109,7 @@ def test_collector_with_dict_state():
         lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
     ]
     envs = VectorEnv(env_fns)
-    c1 = Collector(policy, envs, ReplayBuffer(size=100))
+    c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
     c1.collect(n_step=10)
     c1.collect(n_episode=[2, 1, 1, 2])
     batch = c1.sample(10)

@@ -101,7 +119,8 @@ def test_collector_with_dict_state():
         0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
         0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
         1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
-    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4))
+    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
+                   preprocess_fn)
    c2.collect(n_episode=[0, 0, 0, 10])
    batch = c2.sample(10)
    print(batch['obs_next']['index'])

@@ -130,6 +130,12 @@ class Batch(object):
         return sorted([i for i in self.__dict__ if i[0] != '_'] +
                       list(self._meta))

+    def get(self, k, d=None):
+        """Return self[k] if k in self else d. d defaults to None."""
+        if k in self.__dict__ or k in self._meta:
+            return self.__getattr__(k)
+        return d
+
     def to_numpy(self):
         """Change all torch.Tensor to numpy.ndarray. This is an inplace
         operation.

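The new ``Batch.get`` mirrors ``dict.get``. A minimal sketch of the intended behaviour, assuming ``Batch`` keeps keyword arguments as attributes (the values are illustrative):

    b = Batch(obs=np.array([0, 1]))
    b.get('obs')        # the stored array, equivalent to b.obs
    b.get('rew', 0.0)   # 'rew' is absent, so the default 0.0 is returned
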
@@ -14,15 +14,24 @@ class Collector(object):

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
         class.
-    :param env: an environment or an instance of the
+    :param env: a ``gym.Env`` environment or an instance of the
         :class:`~tianshou.env.BaseVectorEnv` class.
     :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
         class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
         ``None``, it will automatically assign a small-size
         :class:`~tianshou.data.ReplayBuffer`.
+    :param function preprocess_fn: a function called before the data has been
+        added to the buffer, see issue #42, defaults to ``None``.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.

+    The ``preprocess_fn`` is a function called before the data has been added
+    to the buffer with batch format, which receives up to 7 keys as listed in
+    :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the
+    collector resets the environment. It returns either a dict or a
+    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
+    are in "test/base/test_collector.py".
+
     Example:
     ::

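To make the contract described in this docstring concrete, here is a rough user-side sketch; the reward-clipping hook and the surrounding names (``policy``, ``env``) are illustrative assumptions, not part of this commit:

    import numpy as np

    def clip_rew_fn(**kwargs):
        # on reset the collector calls this with only `obs`; on each step it
        # passes up to 7 keys (obs, act, rew, done, obs_next, info, policy)
        if 'rew' in kwargs:
            return {'rew': np.clip(kwargs['rew'], -1.0, 1.0)}
        return {}  # an empty dict leaves every field unchanged

    # collector = Collector(policy, env, ReplayBuffer(size=1000),
    #                       preprocess_fn=clip_rew_fn)
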
@@ -68,15 +77,21 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """

-    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
+    def __init__(self, policy, env, buffer=None, preprocess_fn=None,
+                 stat_size=100, **kwargs):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_time = 0
         self.collect_step = 0
         self.collect_episode = 0
-        self.collect_time = 0
         self.buffer = buffer
         self.policy = policy
+        self.preprocess_fn = preprocess_fn
+        # if preprocess_fn is None:
+        #     def _prep(**kwargs):
+        #         return kwargs
+        #     self.preprocess_fn = _prep
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
         self._multi_buf = False  # True if buf is a list

@@ -119,7 +134,7 @@ class Collector(object):
         self.buffer.reset()

     def get_env_num(self):
-        """Return the number of environments the collector has."""
+        """Return the number of environments the collector have."""
         return self.env_num

     def reset_env(self):

@@ -127,6 +142,10 @@ class Collector(object):
         buffers (if need).
         """
         self._obs = self.env.reset()
+        if not self._multi_env:
+            self._obs = self._make_batch(self._obs)
+        if self.preprocess_fn:
+            self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
         self._act = self._rew = self._done = self._info = None
         if self._multi_env:
             self.reward = np.zeros(self.env_num)

@@ -231,40 +250,43 @@ class Collector(object):
                     'There are already many steps in an episode. '
                     'You should add a time limitation to your environment!',
                     Warning)
-            if self._multi_env:
-                batch_data = Batch(
-                    obs=self._obs, act=self._act, rew=self._rew,
-                    done=self._done, obs_next=None, info=self._info)
-            else:
-                batch_data = Batch(
-                    obs=self._make_batch(self._obs),
-                    act=self._make_batch(self._act),
-                    rew=self._make_batch(self._rew),
-                    done=self._make_batch(self._done),
-                    obs_next=None,
-                    info=self._make_batch(self._info))
+            batch = Batch(
+                obs=self._obs, act=self._act, rew=self._rew,
+                done=self._done, obs_next=None, info=self._info,
+                policy=None)
             with torch.no_grad():
-                result = self.policy(batch_data, self.state)
-            self.state = result.state if hasattr(result, 'state') else None
+                result = self.policy(batch, self.state)
+            self.state = result.get('state', None)
             self._policy = self._to_numpy(result.policy) \
-                if hasattr(result, 'policy') \
-                else [{}] * self.env_num if self._multi_env else {}
-            if isinstance(result.act, torch.Tensor):
-                self._act = self._to_numpy(result.act)
-            elif not isinstance(self._act, np.ndarray):
-                self._act = np.array(result.act)
-            else:
-                self._act = result.act
+                if hasattr(result, 'policy') else [{}] * self.env_num
+            self._act = self._to_numpy(result.act)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
-            if log_fn is not None:
-                log_fn(self._info)
-            if render is not None:
+            if not self._multi_env:
+                obs_next = self._make_batch(obs_next)
+                self._rew = self._make_batch(self._rew)
+                self._done = self._make_batch(self._done)
+                self._info = self._make_batch(self._info)
+            if log_fn:
+                log_fn(self._info if self._multi_env else self._info[0])
+            if render:
                 self.env.render()
                 if render > 0:
                     time.sleep(render)
             self.length += 1
             self.reward += self._rew
+            if self.preprocess_fn:
+                result = self.preprocess_fn(
+                    obs=self._obs, act=self._act, rew=self._rew,
+                    done=self._done, obs_next=obs_next, info=self._info,
+                    policy=self._policy)
+                self._obs = result.get('obs', self._obs)
+                self._act = result.get('act', self._act)
+                self._rew = result.get('rew', self._rew)
+                self._done = result.get('done', self._done)
+                obs_next = result.get('obs_next', obs_next)
+                self._info = result.get('info', self._info)
+                self._policy = result.get('policy', self._policy)
             if self._multi_env:
                 for i in range(self.env_num):
                     data = {

@@ -300,6 +322,9 @@ class Collector(object):
                         self._reset_state(i)
                 if sum(self._done):
                     obs_next = self.env.reset(np.where(self._done)[0])
+                    if self.preprocess_fn:
+                        obs_next = self.preprocess_fn(obs=obs_next).get(
+                            'obs', obs_next)
                 if n_episode != 0:
                     if isinstance(n_episode, list) and \
                             (cur_episode >= np.array(n_episode)).all() or \

@@ -309,16 +334,20 @@ class Collector(object):
             else:
                 if self.buffer is not None:
                     self.buffer.add(
-                        self._obs, self._act[0], self._rew,
-                        self._done, obs_next, self._info, self._policy)
+                        self._obs[0], self._act[0], self._rew[0],
+                        self._done[0], obs_next[0], self._info[0],
+                        self._policy[0])
                 cur_step += 1
                 if self._done:
                     cur_episode += 1
-                    reward_sum += self.reward
+                    reward_sum += self.reward[0]
                     length_sum += self.length
                     self.reward, self.length = 0, 0
                     self.state = None
-                    obs_next = self.env.reset()
+                    obs_next = self._make_batch(self.env.reset())
+                    if self.preprocess_fn:
+                        obs_next = self.preprocess_fn(obs=obs_next).get(
+                            'obs', obs_next)
                 if n_episode != 0 and cur_episode >= n_episode:
                     break
             if n_step != 0 and cur_step >= n_step:

@@ -10,7 +10,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       batch_size,
                       train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
                       log_fn=None, writer=None, log_interval=1, verbose=True,
-                      task='', **kwargs):
+                      **kwargs):
     """A wrapper for off-policy trainer procedure.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

@@ -89,8 +89,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            result[k], global_step=global_step)
+                            k, result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()

@@ -98,8 +97,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=global_step)
                 t.update(1)
                 t.set_postfix(**data)
             if t.n <= t.total:

@@ -10,7 +10,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      episode_per_test, batch_size,
                      train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
                      log_fn=None, writer=None, log_interval=1, verbose=True,
-                     task='', **kwargs):
+                     **kwargs):
     """A wrapper for on-policy trainer procedure.

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

@@ -97,8 +97,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            result[k], global_step=global_step)
+                            k, result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()

@@ -106,8 +105,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer and global_step % log_interval == 0:
                         writer.add_scalar(
-                            k + '_' + task if task else k,
-                            stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=global_step)
                 t.update(step)
                 t.set_postfix(**data)
             if t.n <= t.total: