change the step in trainer (#235)
This PR separates `global_step` into `env_step` and `gradient_step`. Going forward, data from the collecting stage is stored under `env_step`, and data from the updating stage under `gradient_step`.

Other changes:

- add `rew_std` and `best_result` to the monitor
- make the actor network unbounded (`unbounded=True`) in `test/continuous/test_sac_with_il.py` and `examples/box2d/bipedal_hardcore_sac.py`
- bump the ray dependency to 1.0.0, since ray-project/ray#10134 has been resolved
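A minimal sketch of the new bookkeeping, condensed from the off-policy loop in the diff below (the epoch/tqdm handling, test phase, and stop conditions are omitted; `policy`, `train_collector`, and `writer` are assumed to be an existing Tianshou policy, collector, and TensorBoard `SummaryWriter`):

```python
# Hedged sketch, not the actual trainer source: how the two counters advance
# and what gets logged against each one after this PR.
def training_loop(policy, train_collector, writer, collect_per_step,
                  update_per_step, batch_size, log_interval=1, max_step=10000):
    env_step, gradient_step = 0, 0
    while env_step < max_step:
        # collecting stage: statistics (rew, rew_std, len, ...) go under env_step
        result = train_collector.collect(n_step=collect_per_step)
        env_step += int(result["n/st"])
        if writer and env_step % log_interval == 0:
            for k, v in result.items():
                writer.add_scalar("train/" + k, v, global_step=env_step)
        # updating stage: loss curves go under gradient_step
        for _ in range(update_per_step * (int(result["n/st"]) // collect_per_step)):
            losses = policy.update(batch_size, train_collector.buffer)
            gradient_step += 1
            if writer and gradient_step % log_interval == 0:
                for k, v in losses.items():
                    writer.add_scalar(k, v, global_step=gradient_step)
```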
commit 5ed6c1c7aa · parent 710966eda7
| @@ -3,7 +3,7 @@ Cheat Sheet | ||||
| 
 | ||||
| This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios. | ||||
| 
 | ||||
| By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`. | ||||
| By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`. | ||||
| 
 | ||||
| 
 | ||||
| .. _network_api: | ||||
| @@ -163,7 +163,7 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`: | ||||
| 
 | ||||
| Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`. | ||||
| 
 | ||||
| The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state: | ||||
| The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state: | ||||
| 
 | ||||
| - Before: (s, a, s', r, d) stored in replay buffer, and get stacked s; | ||||
| - After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a. | ||||
|  | ||||
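The cheatsheet hunk above describes a wrapper that maps the [s, a] pair to a new state so that frame stacking in the replay buffer also stacks the action. A minimal sketch of such a wrapper, assuming a flat continuous observation and action space and the 4-tuple `gym` step API in use at the time (the class name is illustrative, not part of Tianshou):

```python
import gym
import numpy as np


class StateActionWrapper(gym.Wrapper):
    """Expose [s, a] as the observation so stacked frames carry both."""

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # no previous action at reset time, so pad with zeros
        zero_act = np.zeros(self.env.action_space.shape, dtype=np.float32)
        return np.concatenate([obs, zero_act])

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        act = np.asarray(action, dtype=np.float32).reshape(-1)
        return np.concatenate([obs, act]), rew, done, info
```

In real use the wrapper's `observation_space` would also need to be widened to match the concatenated vector.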
| @@ -93,8 +93,7 @@ def test_sac_bipedal(args=get_args()): | ||||
|     # model | ||||
|     net_a = Net(args.layer_num, args.state_shape, device=args.device) | ||||
|     actor = ActorProb( | ||||
|         net_a, args.action_shape, | ||||
|         args.max_action, args.device | ||||
|         net_a, args.action_shape, args.max_action, args.device, unbounded=True | ||||
|     ).to(args.device) | ||||
|     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) | ||||
| 
 | ||||
|  | ||||
							
								
								
									
setup.py (2 changed lines)
| @@ -60,7 +60,7 @@ setup( | ||||
|             "flake8", | ||||
|             "pytest", | ||||
|             "pytest-cov", | ||||
|             "ray>=0.8.0", | ||||
|             "ray>=1.0.0", | ||||
|             "mypy", | ||||
|             "pydocstyle", | ||||
|             "doc8", | ||||
|  | ||||
| @@ -111,13 +111,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): | ||||
|             t = time.time() - t | ||||
|             ids = Batch(info).env_id | ||||
|             print(ids, t) | ||||
|             if cls != RayVectorEnv:  # ray-project/ray#10134 | ||||
|                 if not (len(ids) == len(res) and np.allclose(sorted(ids), res) | ||||
|                         and (t < timeout) == (len(res) == num - 1)): | ||||
|                     pass_check = 0 | ||||
|                     break | ||||
|             if not (len(ids) == len(res) and np.allclose(sorted(ids), res) | ||||
|                     and (t < timeout) == (len(res) == num - 1)): | ||||
|                 pass_check = 0 | ||||
|                 break | ||||
|         total_pass += pass_check | ||||
|     assert total_pass >= 1  # should be modified when ray>=0.9.0 release | ||||
|     assert total_pass >= 2 | ||||
| 
 | ||||
| 
 | ||||
| def test_vecenv(size=10, num=8, sleep=0.001): | ||||
|  | ||||
| @@ -67,8 +67,7 @@ def test_sac_with_il(args=get_args()): | ||||
|     # model | ||||
|     net = Net(args.layer_num, args.state_shape, device=args.device) | ||||
|     actor = ActorProb( | ||||
|         net, args.action_shape, | ||||
|         args.max_action, args.device | ||||
|         net, args.action_shape, args.max_action, args.device, unbounded=True | ||||
|     ).to(args.device) | ||||
|     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) | ||||
|     net = Net(args.layer_num, args.state_shape, | ||||
| @@ -118,9 +117,10 @@ def test_sac_with_il(args=get_args()): | ||||
|     policy.eval() | ||||
|     if args.task == 'Pendulum-v0': | ||||
|         env.spec.reward_threshold = -300  # lower the goal | ||||
|     net = Actor(Net(1, args.state_shape, device=args.device), | ||||
|                 args.action_shape, args.max_action, args.device | ||||
|                 ).to(args.device) | ||||
|     net = Actor( | ||||
|         Net(1, args.state_shape, device=args.device), | ||||
|         args.action_shape, args.max_action, args.device | ||||
|     ).to(args.device) | ||||
|     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) | ||||
|     il_policy = ImitationPolicy(net, optim, mode='continuous') | ||||
|     il_test_collector = Collector( | ||||
|  | ||||
| @@ -214,7 +214,7 @@ class Collector(object): | ||||
|         # number of episodes, these envs will be recorded in this list, and | ||||
|         # they will not be stepped. | ||||
|         finished_env_ids = [] | ||||
|         reward_total = 0.0 | ||||
|         rewards = [] | ||||
|         whole_data = Batch() | ||||
|         if isinstance(n_episode, list): | ||||
|             assert len(n_episode) == self.get_env_num() | ||||
| @@ -311,7 +311,8 @@ class Collector(object): | ||||
|                     if not (isinstance(n_episode, list) | ||||
|                             and episode_count[i] >= n_episode[i]): | ||||
|                         episode_count[i] += 1 | ||||
|                         reward_total += np.sum(self._cached_buf[i].rew, axis=0) | ||||
|                         rewards.append(self._rew_metric( | ||||
|                             np.sum(self._cached_buf[i].rew, axis=0))) | ||||
|                         step_count += len(self._cached_buf[i]) | ||||
|                         if self.buffer is not None: | ||||
|                             self.buffer.update(self._cached_buf[i]) | ||||
| @@ -361,16 +362,13 @@ class Collector(object): | ||||
|         self.collect_step += step_count | ||||
|         self.collect_episode += episode_count | ||||
|         self.collect_time += duration | ||||
|         # average reward across the number of episodes | ||||
|         reward_avg = reward_total / episode_count | ||||
|         if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg | ||||
|             reward_avg = self._rew_metric(reward_avg)  # type: ignore | ||||
|         return { | ||||
|             "n/ep": episode_count, | ||||
|             "n/st": step_count, | ||||
|             "v/st": step_count / duration, | ||||
|             "v/ep": episode_count / duration, | ||||
|             "rew": reward_avg, | ||||
|             "rew": np.mean(rewards), | ||||
|             "rew_std": np.std(rewards), | ||||
|             "len": step_count / episode_count, | ||||
|         } | ||||
| 
 | ||||
|  | ||||
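With the `Collector` change above, the statistics returned by `collect` now report the spread of episode returns in addition to the mean. A hedged usage sketch (`train_collector` is assumed to be an existing `tianshou.data.Collector`):

```python
# keys after this PR: n/ep, n/st, v/st, v/ep, rew, rew_std, len
result = train_collector.collect(n_episode=10)
print(f"return: {result['rew']:.2f} ± {result['rew_std']:.2f} "
      f"over {int(result['n/ep'])} episodes")
```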
| @@ -71,8 +71,8 @@ def offpolicy_trainer( | ||||
| 
 | ||||
|     :return: See :func:`~tianshou.trainer.gather_info`. | ||||
|     """ | ||||
|     global_step = 0 | ||||
|     best_epoch, best_reward = -1, -1.0 | ||||
|     env_step, gradient_step = 0, 0 | ||||
|     best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 | ||||
|     stat: Dict[str, MovAvg] = {} | ||||
|     start_time = time.time() | ||||
|     train_collector.reset_stat() | ||||
| @@ -86,13 +86,26 @@ def offpolicy_trainer( | ||||
|         ) as t: | ||||
|             while t.n < t.total: | ||||
|                 if train_fn: | ||||
|                     train_fn(epoch, global_step) | ||||
|                     train_fn(epoch, env_step) | ||||
|                 result = train_collector.collect(n_step=collect_per_step) | ||||
|                 data = {} | ||||
|                 env_step += int(result["n/st"]) | ||||
|                 data = { | ||||
|                     "env_step": str(env_step), | ||||
|                     "rew": f"{result['rew']:.2f}", | ||||
|                     "len": str(int(result["len"])), | ||||
|                     "n/ep": str(int(result["n/ep"])), | ||||
|                     "n/st": str(int(result["n/st"])), | ||||
|                     "v/ep": f"{result['v/ep']:.2f}", | ||||
|                     "v/st": f"{result['v/st']:.2f}", | ||||
|                 } | ||||
|                 if writer and env_step % log_interval == 0: | ||||
|                     for k in result.keys(): | ||||
|                         writer.add_scalar( | ||||
|                             "train/" + k, result[k], global_step=env_step) | ||||
|                 if test_in_train and stop_fn and stop_fn(result["rew"]): | ||||
|                     test_result = test_episode( | ||||
|                         policy, test_collector, test_fn, | ||||
|                         epoch, episode_per_test, writer, global_step) | ||||
|                         epoch, episode_per_test, writer, env_step) | ||||
|                     if stop_fn(test_result["rew"]): | ||||
|                         if save_fn: | ||||
|                             save_fn(policy) | ||||
| @@ -101,42 +114,38 @@ def offpolicy_trainer( | ||||
|                         t.set_postfix(**data) | ||||
|                         return gather_info( | ||||
|                             start_time, train_collector, test_collector, | ||||
|                             test_result["rew"]) | ||||
|                             test_result["rew"], test_result["rew_std"]) | ||||
|                     else: | ||||
|                         policy.train() | ||||
|                 for i in range(update_per_step * min( | ||||
|                         result["n/st"] // collect_per_step, t.total - t.n)): | ||||
|                     global_step += collect_per_step | ||||
|                     gradient_step += 1 | ||||
|                     losses = policy.update(batch_size, train_collector.buffer) | ||||
|                     for k in result.keys(): | ||||
|                         data[k] = f"{result[k]:.2f}" | ||||
|                         if writer and global_step % log_interval == 0: | ||||
|                             writer.add_scalar("train/" + k, result[k], | ||||
|                                               global_step=global_step) | ||||
|                     for k in losses.keys(): | ||||
|                         if stat.get(k) is None: | ||||
|                             stat[k] = MovAvg() | ||||
|                         stat[k].add(losses[k]) | ||||
|                         data[k] = f"{stat[k].get():.6f}" | ||||
|                         if writer and global_step % log_interval == 0: | ||||
|                         if writer and gradient_step % log_interval == 0: | ||||
|                             writer.add_scalar( | ||||
|                                 k, stat[k].get(), global_step=global_step) | ||||
|                                 k, stat[k].get(), global_step=gradient_step) | ||||
|                     t.update(1) | ||||
|                     t.set_postfix(**data) | ||||
|             if t.n <= t.total: | ||||
|                 t.update() | ||||
|         # test | ||||
|         result = test_episode(policy, test_collector, test_fn, epoch, | ||||
|                               episode_per_test, writer, global_step) | ||||
|                               episode_per_test, writer, env_step) | ||||
|         if best_epoch == -1 or best_reward < result["rew"]: | ||||
|             best_reward = result["rew"] | ||||
|             best_reward, best_reward_std = result["rew"], result["rew_std"] | ||||
|             best_epoch = epoch | ||||
|             if save_fn: | ||||
|                 save_fn(policy) | ||||
|         if verbose: | ||||
|             print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " | ||||
|                   f"best_reward: {best_reward:.6f} in #{best_epoch}") | ||||
|             print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " | ||||
|                   f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " | ||||
|                   f"{best_reward_std:.6f} in #{best_epoch}") | ||||
|         if stop_fn and stop_fn(best_reward): | ||||
|             break | ||||
|     return gather_info( | ||||
|         start_time, train_collector, test_collector, best_reward) | ||||
|     return gather_info(start_time, train_collector, test_collector, | ||||
|                        best_reward, best_reward_std) | ||||
|  | ||||
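Because the trainers now hand the environment-step counter to the callbacks, schedules that were written against `global_step` should be restated in environment steps. A hedged sketch of the updated callback, assuming a DQN-style policy with `set_eps` (the decay constants are made up for illustration):

```python
def make_train_fn(policy, eps_start=0.5, eps_end=0.05, decay_steps=10000):
    # hypothetical linear epsilon decay expressed in environment steps
    def train_fn(epoch, env_step):  # callback signature after this PR
        frac = min(1.0, env_step / decay_steps)
        policy.set_eps(eps_start + frac * (eps_end - eps_start))
    return train_fn
```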
| @@ -71,8 +71,8 @@ def onpolicy_trainer( | ||||
| 
 | ||||
|     :return: See :func:`~tianshou.trainer.gather_info`. | ||||
|     """ | ||||
|     global_step = 0 | ||||
|     best_epoch, best_reward = -1, -1.0 | ||||
|     env_step, gradient_step = 0, 0 | ||||
|     best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 | ||||
|     stat: Dict[str, MovAvg] = {} | ||||
|     start_time = time.time() | ||||
|     train_collector.reset_stat() | ||||
| @@ -86,13 +86,26 @@ def onpolicy_trainer( | ||||
|         ) as t: | ||||
|             while t.n < t.total: | ||||
|                 if train_fn: | ||||
|                     train_fn(epoch, global_step) | ||||
|                     train_fn(epoch, env_step) | ||||
|                 result = train_collector.collect(n_episode=collect_per_step) | ||||
|                 data = {} | ||||
|                 env_step += int(result["n/st"]) | ||||
|                 data = { | ||||
|                     "env_step": str(env_step), | ||||
|                     "rew": f"{result['rew']:.2f}", | ||||
|                     "len": str(int(result["len"])), | ||||
|                     "n/ep": str(int(result["n/ep"])), | ||||
|                     "n/st": str(int(result["n/st"])), | ||||
|                     "v/ep": f"{result['v/ep']:.2f}", | ||||
|                     "v/st": f"{result['v/st']:.2f}", | ||||
|                 } | ||||
|                 if writer and env_step % log_interval == 0: | ||||
|                     for k in result.keys(): | ||||
|                         writer.add_scalar( | ||||
|                             "train/" + k, result[k], global_step=env_step) | ||||
|                 if test_in_train and stop_fn and stop_fn(result["rew"]): | ||||
|                     test_result = test_episode( | ||||
|                         policy, test_collector, test_fn, | ||||
|                         epoch, episode_per_test, writer, global_step) | ||||
|                         epoch, episode_per_test, writer, env_step) | ||||
|                     if stop_fn(test_result["rew"]): | ||||
|                         if save_fn: | ||||
|                             save_fn(policy) | ||||
| @@ -101,47 +114,41 @@ def onpolicy_trainer( | ||||
|                         t.set_postfix(**data) | ||||
|                         return gather_info( | ||||
|                             start_time, train_collector, test_collector, | ||||
|                             test_result["rew"]) | ||||
|                             test_result["rew"], test_result["rew_std"]) | ||||
|                     else: | ||||
|                         policy.train() | ||||
|                 losses = policy.update( | ||||
|                     0, train_collector.buffer, | ||||
|                     batch_size=batch_size, repeat=repeat_per_collect) | ||||
|                 train_collector.reset_buffer() | ||||
|                 step = 1 | ||||
|                 for v in losses.values(): | ||||
|                     if isinstance(v, list): | ||||
|                         step = max(step, len(v)) | ||||
|                 global_step += step * collect_per_step | ||||
|                 for k in result.keys(): | ||||
|                     data[k] = f"{result[k]:.2f}" | ||||
|                     if writer and global_step % log_interval == 0: | ||||
|                         writer.add_scalar( | ||||
|                             "train/" + k, result[k], global_step=global_step) | ||||
|                 step = max([1] + [ | ||||
|                     len(v) for v in losses.values() if isinstance(v, list)]) | ||||
|                 gradient_step += step | ||||
|                 for k in losses.keys(): | ||||
|                     if stat.get(k) is None: | ||||
|                         stat[k] = MovAvg() | ||||
|                     stat[k].add(losses[k]) | ||||
|                     data[k] = f"{stat[k].get():.6f}" | ||||
|                     if writer and global_step % log_interval == 0: | ||||
|                     if writer and gradient_step % log_interval == 0: | ||||
|                         writer.add_scalar( | ||||
|                             k, stat[k].get(), global_step=global_step) | ||||
|                             k, stat[k].get(), global_step=gradient_step) | ||||
|                 t.update(step) | ||||
|                 t.set_postfix(**data) | ||||
|             if t.n <= t.total: | ||||
|                 t.update() | ||||
|         # test | ||||
|         result = test_episode(policy, test_collector, test_fn, epoch, | ||||
|                               episode_per_test, writer, global_step) | ||||
|                               episode_per_test, writer, env_step) | ||||
|         if best_epoch == -1 or best_reward < result["rew"]: | ||||
|             best_reward = result["rew"] | ||||
|             best_reward, best_reward_std = result["rew"], result["rew_std"] | ||||
|             best_epoch = epoch | ||||
|             if save_fn: | ||||
|                 save_fn(policy) | ||||
|         if verbose: | ||||
|             print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " | ||||
|                   f"best_reward: {best_reward:.6f} in #{best_epoch}") | ||||
|             print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " | ||||
|                   f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " | ||||
|                   f"{best_reward_std:.6f} in #{best_epoch}") | ||||
|         if stop_fn and stop_fn(best_reward): | ||||
|             break | ||||
|     return gather_info( | ||||
|         start_time, train_collector, test_collector, best_reward) | ||||
|     return gather_info(start_time, train_collector, test_collector, | ||||
|                        best_reward, best_reward_std) | ||||
|  | ||||
| @@ -39,6 +39,7 @@ def gather_info( | ||||
|     train_c: Collector, | ||||
|     test_c: Collector, | ||||
|     best_reward: float, | ||||
|     best_reward_std: float, | ||||
| ) -> Dict[str, Union[float, str]]: | ||||
|     """A simple wrapper of gathering information from collectors. | ||||
| 
 | ||||
| @@ -72,5 +73,6 @@ def gather_info( | ||||
|         "test_time": f"{test_c.collect_time:.2f}s", | ||||
|         "test_speed": f"{test_speed:.2f} step/s", | ||||
|         "best_reward": best_reward, | ||||
|         "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", | ||||
|         "duration": f"{duration:.2f}s", | ||||
|     } | ||||
|  | ||||
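Finally, the info dict returned by the trainers (via `gather_info`) now carries a pre-formatted `best_result` string next to the raw `best_reward`. A hedged usage sketch (the argument list is abbreviated and every name is assumed to be defined already):

```python
info = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch, step_per_epoch, collect_per_step,
    episode_per_test, batch_size, writer=writer)
print(info["best_reward"])   # float, e.g. 200.0
print(info["best_result"])   # str, e.g. "200.00 ± 3.50"
```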