change the step in trainer (#235)

This PR separates `global_step` into `env_step` and `gradient_step`. From now on, data from the collecting state is logged under `env_step`, and data from the updating state is logged under `gradient_step`.

Others:
- add `rew_std` and `best_result` to the monitored statistics
- make the actor network output unbounded in `test/continuous/test_sac_with_il.py` and `examples/box2d/bipedal_hardcore_sac.py` by passing `unbounded=True` to `ActorProb`
- bump the `ray` dependency to `>=1.0.0`, since ray-project/ray#10134 has been resolved
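
A minimal sketch of the new bookkeeping, with a made-up `fake_collect` helper standing in for the real collector (only the two counters and the `n/st` / `rew` / `rew_std` keys mirror the diffs below; everything else is simplified):

```python
import numpy as np


def fake_collect(n_step: int) -> dict:
    """Pretend to collect n_step transitions and return collector-style stats."""
    episode_returns = np.random.uniform(-200, 0, size=5)
    return {
        "n/st": n_step,
        "rew": float(np.mean(episode_returns)),
        "rew_std": float(np.std(episode_returns)),  # new key added in this PR
    }


env_step, gradient_step = 0, 0
for epoch in range(3):
    result = fake_collect(n_step=100)
    env_step += int(result["n/st"])   # collecting state -> counted as env_step
    for _ in range(10):               # pretend gradient updates on the buffer
        gradient_step += 1            # updating state -> counted as gradient_step
    print(f"epoch {epoch}: env_step={env_step}, gradient_step={gradient_step}, "
          f"rew={result['rew']:.2f} ± {result['rew_std']:.2f}")
```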
Authored by n+e on 2020-10-04 21:55:43 +08:00, committed by GitHub (parent 710966eda7, commit 5ed6c1c7aa).
9 changed files with 81 additions and 67 deletions


@@ -3,7 +3,7 @@ Cheat Sheet
 This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
-By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
 .. _network_api:
@@ -163,7 +163,7 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`:
 Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
-The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:
+The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:
 - Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
 - After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
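
For reference, a wrapper of the kind described in the snippet above could look roughly like this sketch against the classic `gym` API of that era; `PrevActionWrapper` and the use of `Pendulum-v0` are illustrative and not part of this diff:

```python
import gym
import numpy as np


class PrevActionWrapper(gym.Wrapper):
    """Append the previous action to the observation, so the buffer stores [s, a]."""

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        zero_act = np.zeros(self.env.action_space.shape[0])  # no action taken yet
        return np.concatenate([np.asarray(obs).ravel(), zero_act])

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        obs = np.concatenate([np.asarray(obs).ravel(), np.asarray(action).ravel()])
        return obs, rew, done, info


env = PrevActionWrapper(gym.make("Pendulum-v0"))  # stacked frames now contain both s and a
```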


@@ -93,8 +93,7 @@ def test_sac_bipedal(args=get_args()):
     # model
     net_a = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        net_a, args.action_shape,
-        args.max_action, args.device
+        net_a, args.action_shape, args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)


@@ -60,7 +60,7 @@ setup(
         "flake8",
         "pytest",
         "pytest-cov",
-        "ray>=0.8.0",
+        "ray>=1.0.0",
         "mypy",
         "pydocstyle",
         "doc8",


@@ -111,13 +111,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
             t = time.time() - t
             ids = Batch(info).env_id
             print(ids, t)
-            if cls != RayVectorEnv:  # ray-project/ray#10134
-                if not (len(ids) == len(res) and np.allclose(sorted(ids), res)
-                        and (t < timeout) == (len(res) == num - 1)):
-                    pass_check = 0
-                    break
+            if not (len(ids) == len(res) and np.allclose(sorted(ids), res)
+                    and (t < timeout) == (len(res) == num - 1)):
+                pass_check = 0
+                break
         total_pass += pass_check
-    assert total_pass >= 1  # should be modified when ray>=0.9.0 release
+    assert total_pass >= 2
 def test_vecenv(size=10, num=8, sleep=0.001):


@@ -67,8 +67,7 @@ def test_sac_with_il(args=get_args()):
     # model
     net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        net, args.action_shape,
-        args.max_action, args.device
+        net, args.action_shape, args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net = Net(args.layer_num, args.state_shape,
@@ -118,9 +117,10 @@ def test_sac_with_il(args=get_args()):
     policy.eval()
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -300  # lower the goal
-    net = Actor(Net(1, args.state_shape, device=args.device),
-                args.action_shape, args.max_action, args.device
-                ).to(args.device)
+    net = Actor(
+        Net(1, args.state_shape, device=args.device),
+        args.action_shape, args.max_action, args.device
+    ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(


@@ -214,7 +214,7 @@ class Collector(object):
         # number of episodes, these envs will be recorded in this list, and
         # they will not be stepped.
         finished_env_ids = []
-        reward_total = 0.0
+        rewards = []
        whole_data = Batch()
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
@@ -311,7 +311,8 @@
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
-                        reward_total += np.sum(self._cached_buf[i].rew, axis=0)
+                        rewards.append(self._rew_metric(
+                            np.sum(self._cached_buf[i].rew, axis=0)))
                        step_count += len(self._cached_buf[i])
                    if self.buffer is not None:
                        self.buffer.update(self._cached_buf[i])
@@ -361,16 +362,13 @@
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
-        # average reward across the number of episodes
-        reward_avg = reward_total / episode_count
-        if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg
-            reward_avg = self._rew_metric(reward_avg)  # type: ignore
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
-            "rew": reward_avg,
+            "rew": np.mean(rewards),
+            "rew_std": np.std(rewards),
            "len": step_count / episode_count,
        }


@@ -71,8 +71,8 @@ def offpolicy_trainer(
     :return: See :func:`~tianshou.trainer.gather_info`.
     """
-    global_step = 0
-    best_epoch, best_reward = -1, -1.0
+    env_step, gradient_step = 0, 0
+    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
     train_collector.reset_stat()
@@ -86,13 +86,26 @@ def offpolicy_trainer(
        ) as t:
            while t.n < t.total:
                if train_fn:
-                    train_fn(epoch, global_step)
+                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=collect_per_step)
-                data = {}
+                env_step += int(result["n/st"])
+                data = {
+                    "env_step": str(env_step),
+                    "rew": f"{result['rew']:.2f}",
+                    "len": str(int(result["len"])),
+                    "n/ep": str(int(result["n/ep"])),
+                    "n/st": str(int(result["n/st"])),
+                    "v/ep": f"{result['v/ep']:.2f}",
+                    "v/st": f"{result['v/st']:.2f}",
+                }
+                if writer and env_step % log_interval == 0:
+                    for k in result.keys():
+                        writer.add_scalar(
+                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
-                        epoch, episode_per_test, writer, global_step)
+                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
@@ -101,42 +114,38 @@
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
-                            test_result["rew"])
+                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                for i in range(update_per_step * min(
                        result["n/st"] // collect_per_step, t.total - t.n)):
-                    global_step += collect_per_step
+                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
-                    for k in result.keys():
-                        data[k] = f"{result[k]:.2f}"
-                        if writer and global_step % log_interval == 0:
-                            writer.add_scalar("train/" + k, result[k],
-                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f"{stat[k].get():.6f}"
-                        if writer and global_step % log_interval == 0:
+                        if writer and gradient_step % log_interval == 0:
                            writer.add_scalar(
-                                k, stat[k].get(), global_step=global_step)
+                                k, stat[k].get(), global_step=gradient_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
-                              episode_per_test, writer, global_step)
+                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
-            best_reward = result["rew"]
+            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
        if save_fn:
            save_fn(policy)
        if verbose:
-            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
-                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
+                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
+                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
-    return gather_info(
-        start_time, train_collector, test_collector, best_reward)
+    return gather_info(start_time, train_collector, test_collector,
+                       best_reward, best_reward_std)


@@ -71,8 +71,8 @@ def onpolicy_trainer(
     :return: See :func:`~tianshou.trainer.gather_info`.
     """
-    global_step = 0
-    best_epoch, best_reward = -1, -1.0
+    env_step, gradient_step = 0, 0
+    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
     train_collector.reset_stat()
@@ -86,13 +86,26 @@ def onpolicy_trainer(
        ) as t:
            while t.n < t.total:
                if train_fn:
-                    train_fn(epoch, global_step)
+                    train_fn(epoch, env_step)
                result = train_collector.collect(n_episode=collect_per_step)
-                data = {}
+                env_step += int(result["n/st"])
+                data = {
+                    "env_step": str(env_step),
+                    "rew": f"{result['rew']:.2f}",
+                    "len": str(int(result["len"])),
+                    "n/ep": str(int(result["n/ep"])),
+                    "n/st": str(int(result["n/st"])),
+                    "v/ep": f"{result['v/ep']:.2f}",
+                    "v/st": f"{result['v/st']:.2f}",
+                }
+                if writer and env_step % log_interval == 0:
+                    for k in result.keys():
+                        writer.add_scalar(
+                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
-                        epoch, episode_per_test, writer, global_step)
+                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
@@ -101,47 +114,41 @@
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
-                            test_result["rew"])
+                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
-                step = 1
-                for v in losses.values():
-                    if isinstance(v, list):
-                        step = max(step, len(v))
-                global_step += step * collect_per_step
-                for k in result.keys():
-                    data[k] = f"{result[k]:.2f}"
-                    if writer and global_step % log_interval == 0:
-                        writer.add_scalar(
-                            "train/" + k, result[k], global_step=global_step)
+                step = max([1] + [
+                    len(v) for v in losses.values() if isinstance(v, list)])
+                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
-                    if writer and global_step % log_interval == 0:
+                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
-                            k, stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
-                              episode_per_test, writer, global_step)
+                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
-            best_reward = result["rew"]
+            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
        if save_fn:
            save_fn(policy)
        if verbose:
-            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
-                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
+                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
+                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
-    return gather_info(
-        start_time, train_collector, test_collector, best_reward)
+    return gather_info(start_time, train_collector, test_collector,
+                       best_reward, best_reward_std)


@@ -39,6 +39,7 @@ def gather_info(
     train_c: Collector,
     test_c: Collector,
     best_reward: float,
+    best_reward_std: float,
 ) -> Dict[str, Union[float, str]]:
     """A simple wrapper of gathering information from collectors.
@@ -72,5 +73,6 @@ def gather_info(
        "test_time": f"{test_c.collect_time:.2f}s",
        "test_speed": f"{test_speed:.2f} step/s",
        "best_reward": best_reward,
+        "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
        "duration": f"{duration:.2f}s",
    }
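
For illustration, the info dict returned by `gather_info` then contains entries such as the following (numbers are invented; only the keys visible in the hunk above are listed, other keys are omitted):

```python
# Invented values; key names follow the diff above.
info = {
    "test_time": "3.21s",
    "test_speed": "1520.33 step/s",
    "best_reward": -154.3,
    "best_result": "-154.30 ± 12.40",  # new monitored entry
    "duration": "45.67s",
}
```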