From 5ed6c1c7aaff15e5f8433c804a1f2a635734a92c Mon Sep 17 00:00:00 2001
From: n+e
Date: Sun, 4 Oct 2020 21:55:43 +0800
Subject: [PATCH] change the step in trainer (#235)

This PR separates `global_step` into `env_step` and `gradient_step`. Going forward, data from the collecting stage will be stored under `env_step`, and data from the updating stage will be stored under `gradient_step`.

Other changes:

- add `rew_std` and `best_result` to the monitor
- make the actor network unbounded in `test/continuous/test_sac_with_il.py` and `examples/box2d/bipedal_hardcore_sac.py`
- bump the ray dependency to 1.0.0, since ray-project/ray#10134 has been resolved

---
 docs/tutorials/cheatsheet.rst | 4 +-
 examples/box2d/bipedal_hardcore_sac.py | 3 +-
 setup.py | 2 +-
 test/base/test_env.py | 11 +++---
 test/continuous/test_sac_with_il.py | 10 ++---
 tianshou/data/collector.py | 12 +++---
 tianshou/trainer/offpolicy.py | 49 +++++++++++++----------
 tianshou/trainer/onpolicy.py | 55 +++++++++++++++-----------
 tianshou/trainer/utils.py | 2 +
 9 files changed, 81 insertions(+), 67 deletions(-)

diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index c088a8d..07155f3 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -3,7 +3,7 @@ Cheat Sheet This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios. -By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`. +By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It can serve as a universal solution for policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`. .. _network_api: @@ -163,7 +163,7 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`: Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`. -The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state: +The above code supports only stacked observations. If you want to use stacked actions (for Q(stacked-s, stacked-a)), stacked rewards, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that maps the [s, a] pair to a new state: - Before: (s, a, s', r, d) stored in replay buffer, and get stacked s; - After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
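The cheat sheet hunk above describes adding a ``gym.Wrapper`` that maps the [s, a] pair to a new state so that stacked actions are stored alongside stacked observations. A minimal sketch of such a wrapper (not part of this patch; it assumes Box observation and action spaces, the pre-0.26 Gym API, and the made-up class name `StateActionWrapper`) could look like:

```python
import gym
import numpy as np


class StateActionWrapper(gym.Wrapper):
    """Map s -> [s, a] by appending the previous action to the observation."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box)
        assert isinstance(env.action_space, gym.spaces.Box)
        self._act_dim = int(np.prod(env.action_space.shape))
        low = np.concatenate(
            [env.observation_space.low, np.full(self._act_dim, -np.inf)])
        high = np.concatenate(
            [env.observation_space.high, np.full(self._act_dim, np.inf)])
        self.observation_space = gym.spaces.Box(
            low=low, high=high, dtype=np.float32)

    def reset(self):
        obs = self.env.reset()
        # no previous action at reset time, pad with zeros
        act = np.zeros(self._act_dim, dtype=np.float32)
        return np.concatenate([obs, act]).astype(np.float32)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        # store [s', a] as the new observation
        obs = np.concatenate([obs, np.ravel(action)]).astype(np.float32)
        return obs, rew, done, info
```

Wrapping an environment this way, e.g. `StateActionWrapper(gym.make("Pendulum-v0"))`, makes the replay buffer store ([s, a], a, [s', a'], r, d), so frame stacking yields both stacked s and stacked a.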
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index b4da185..0ddc682 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -93,8 +93,7 @@ def test_sac_bipedal(args=get_args()): # model net_a = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - net_a, args.action_shape, - args.max_action, args.device + net_a, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) diff --git a/setup.py b/setup.py index 93acc4a..a7410ce 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ setup( "flake8", "pytest", "pytest-cov", - "ray>=0.8.0", + "ray>=1.0.0", "mypy", "pydocstyle", "doc8", diff --git a/test/base/test_env.py b/test/base/test_env.py index 8d2c78a..82c0b1a 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -111,13 +111,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): t = time.time() - t ids = Batch(info).env_id print(ids, t) - if cls != RayVectorEnv: # ray-project/ray#10134 - if not (len(ids) == len(res) and np.allclose(sorted(ids), res) - and (t < timeout) == (len(res) == num - 1)): - pass_check = 0 - break + if not (len(ids) == len(res) and np.allclose(sorted(ids), res) + and (t < timeout) == (len(res) == num - 1)): + pass_check = 0 + break total_pass += pass_check - assert total_pass >= 1 # should be modified when ray>=0.9.0 release + assert total_pass >= 2 def test_vecenv(size=10, num=8, sleep=0.001): diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 009218c..865a1c8 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -67,8 +67,7 @@ def test_sac_with_il(args=get_args()): # model net = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - net, args.action_shape, - args.max_action, args.device + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, @@ -118,9 +117,10 @@ def test_sac_with_il(args=get_args()): policy.eval() if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal - net = Actor(Net(1, args.state_shape, device=args.device), - args.action_shape, args.max_action, args.device - ).to(args.device) + net = Actor( + Net(1, args.state_shape, device=args.device), + args.action_shape, args.max_action, args.device + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='continuous') il_test_collector = Collector( diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0d755fb..25f268d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -214,7 +214,7 @@ class Collector(object): # number of episodes, these envs will be recorded in this list, and # they will not be stepped. 
finished_env_ids = [] - reward_total = 0.0 + rewards = [] whole_data = Batch() if isinstance(n_episode, list): assert len(n_episode) == self.get_env_num() @@ -311,7 +311,8 @@ class Collector(object): if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]): episode_count[i] += 1 - reward_total += np.sum(self._cached_buf[i].rew, axis=0) + rewards.append(self._rew_metric( + np.sum(self._cached_buf[i].rew, axis=0))) step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) @@ -361,16 +362,13 @@ class Collector(object): self.collect_step += step_count self.collect_episode += episode_count self.collect_time += duration - # average reward across the number of episodes - reward_avg = reward_total / episode_count - if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg - reward_avg = self._rew_metric(reward_avg) # type: ignore return { "n/ep": episode_count, "n/st": step_count, "v/st": step_count / duration, "v/ep": episode_count / duration, - "rew": reward_avg, + "rew": np.mean(rewards), + "rew_std": np.std(rewards), "len": step_count / episode_count, } diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 75fd6cf..fb4b6f4 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -71,8 +71,8 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ - global_step = 0 - best_epoch, best_reward = -1, -1.0 + env_step, gradient_step = 0, 0 + best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = {} start_time = time.time() train_collector.reset_stat() @@ -86,13 +86,26 @@ def offpolicy_trainer( ) as t: while t.n < t.total: if train_fn: - train_fn(epoch, global_step) + train_fn(epoch, env_step) result = train_collector.collect(n_step=collect_per_step) - data = {} + env_step += int(result["n/st"]) + data = { + "env_step": str(env_step), + "rew": f"{result['rew']:.2f}", + "len": str(int(result["len"])), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + "v/ep": f"{result['v/ep']:.2f}", + "v/st": f"{result['v/st']:.2f}", + } + if writer and env_step % log_interval == 0: + for k in result.keys(): + writer.add_scalar( + "train/" + k, result[k], global_step=env_step) if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test, writer, global_step) + epoch, episode_per_test, writer, env_step) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) @@ -101,42 +114,38 @@ def offpolicy_trainer( t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"]) + test_result["rew"], test_result["rew_std"]) else: policy.train() for i in range(update_per_step * min( result["n/st"] // collect_per_step, t.total - t.n)): - global_step += collect_per_step + gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) - for k in result.keys(): - data[k] = f"{result[k]:.2f}" - if writer and global_step % log_interval == 0: - writer.add_scalar("train/" + k, result[k], - global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" - if writer and global_step % log_interval == 0: + if writer and gradient_step % log_interval == 0: writer.add_scalar( - k, stat[k].get(), global_step=global_step) + k, stat[k].get(), global_step=gradient_step) t.update(1) t.set_postfix(**data) if t.n <= t.total: 
t.update() # test result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, global_step) + episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rew"]: - best_reward = result["rew"] + best_reward, best_reward_std = result["rew"], result["rew_std"] best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " - f"best_reward: {best_reward:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break - return gather_info( - start_time, train_collector, test_collector, best_reward) + return gather_info(start_time, train_collector, test_collector, + best_reward, best_reward_std) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 023dd29..5aff68b 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -71,8 +71,8 @@ def onpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ - global_step = 0 - best_epoch, best_reward = -1, -1.0 + env_step, gradient_step = 0, 0 + best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = {} start_time = time.time() train_collector.reset_stat() @@ -86,13 +86,26 @@ def onpolicy_trainer( ) as t: while t.n < t.total: if train_fn: - train_fn(epoch, global_step) + train_fn(epoch, env_step) result = train_collector.collect(n_episode=collect_per_step) - data = {} + env_step += int(result["n/st"]) + data = { + "env_step": str(env_step), + "rew": f"{result['rew']:.2f}", + "len": str(int(result["len"])), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + "v/ep": f"{result['v/ep']:.2f}", + "v/st": f"{result['v/st']:.2f}", + } + if writer and env_step % log_interval == 0: + for k in result.keys(): + writer.add_scalar( + "train/" + k, result[k], global_step=env_step) if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test, writer, global_step) + epoch, episode_per_test, writer, env_step) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) @@ -101,47 +114,41 @@ def onpolicy_trainer( t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"]) + test_result["rew"], test_result["rew_std"]) else: policy.train() losses = policy.update( 0, train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect) train_collector.reset_buffer() - step = 1 - for v in losses.values(): - if isinstance(v, list): - step = max(step, len(v)) - global_step += step * collect_per_step - for k in result.keys(): - data[k] = f"{result[k]:.2f}" - if writer and global_step % log_interval == 0: - writer.add_scalar( - "train/" + k, result[k], global_step=global_step) + step = max([1] + [ + len(v) for v in losses.values() if isinstance(v, list)]) + gradient_step += step for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" - if writer and global_step % log_interval == 0: + if writer and gradient_step % log_interval == 0: writer.add_scalar( - k, stat[k].get(), global_step=global_step) + k, stat[k].get(), global_step=gradient_step) t.update(step) t.set_postfix(**data) if t.n <= t.total: t.update() # test result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, 
writer, global_step) + episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rew"]: - best_reward = result["rew"] + best_reward, best_reward_std = result["rew"], result["rew_std"] best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " - f"best_reward: {best_reward:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break - return gather_info( - start_time, train_collector, test_collector, best_reward) + return gather_info(start_time, train_collector, test_collector, + best_reward, best_reward_std) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2c2fb54..da9dea8 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -39,6 +39,7 @@ def gather_info( train_c: Collector, test_c: Collector, best_reward: float, + best_reward_std: float, ) -> Dict[str, Union[float, str]]: """A simple wrapper of gathering information from collectors. @@ -72,5 +73,6 @@ def gather_info( "test_time": f"{test_c.collect_time:.2f}s", "test_speed": f"{test_speed:.2f} step/s", "best_reward": best_reward, + "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", }
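To summarize the bookkeeping this patch introduces: collect-stage statistics (including the new `rew_std`) are plotted against `env_step`, while loss statistics are plotted against `gradient_step`. The sketch below restates that split with two hypothetical helpers (`log_collect_stats` and `log_update_stats` are illustrative names, not Tianshou APIs); only `SummaryWriter.add_scalar` from TensorBoard is assumed:

```python
from typing import Dict

from torch.utils.tensorboard import SummaryWriter


def log_collect_stats(writer: SummaryWriter, result: Dict[str, float],
                      env_step: int) -> None:
    # collecting stage: the x-axis is the number of environment steps so far
    for k, v in result.items():
        writer.add_scalar("train/" + k, v, global_step=env_step)


def log_update_stats(writer: SummaryWriter, losses: Dict[str, float],
                     gradient_step: int) -> None:
    # updating stage: the x-axis is the number of gradient steps so far
    for k, v in losses.items():
        writer.add_scalar(k, v, global_step=gradient_step)


if __name__ == "__main__":
    writer = SummaryWriter("log/step_split_demo")
    # hypothetical values, only to show the two separate x-axes
    log_collect_stats(writer, {"rew": 195.3, "rew_std": 12.1}, env_step=4096)
    log_update_stats(writer, {"loss/actor": 0.42}, gradient_step=128)
    writer.close()
```

Environment interaction and gradient updates advance at different rates, which is why the trainers above replace the single `global_step`: the off-policy trainer increments `gradient_step` once per `policy.update` call, while the on-policy trainer advances it by the number of gradient steps inferred from the loss lists.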