From 60cfc373f8afb0eaa8716a3d21f14f01af0d74dc Mon Sep 17 00:00:00 2001
From: danagi <420147879@qq.com>
Date: Sat, 27 Jun 2020 21:40:09 +0800
Subject: [PATCH] fix #98, support #99 (#102)

* Add auto alpha tuning and exploration noise for SAC. Add class BaseNoise
  and GaussianNoise for the concept of exploration noise. Add a new test for
  SAC on MountainCarContinuous-v0, which should benefit from the two new
  features above.

* Add exploration noise to the collector, and fix the examples to adapt to
  this modification.

* Fix #98.

* Enable off-policy algorithms to update multiple times in one step (#99).
---
 examples/ant_v2_sac.py               | 2 +-
 examples/continuous_net.py           | 6 ++++--
 examples/halfcheetahBullet_v0_sac.py | 2 +-
 examples/sac_mcc.py                  | 2 +-
 tianshou/trainer/offpolicy.py        | 8 ++++++--
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py
index 8632e92..1d28615 100644
--- a/examples/ant_v2_sac.py
+++ b/examples/ant_v2_sac.py
@@ -60,7 +60,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
diff --git a/examples/continuous_net.py b/examples/continuous_net.py
index 1598a1f..c76ab17 100644
--- a/examples/continuous_net.py
+++ b/examples/continuous_net.py
@@ -28,7 +28,7 @@ class Actor(nn.Module):
 
 class ActorProb(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
+                 max_action, device='cpu', unbounded=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -40,6 +40,7 @@ class ActorProb(nn.Module):
         self.mu = nn.Linear(128, np.prod(action_shape))
         self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
+        self._unbounded = unbounded
 
     def forward(self, s, **kwargs):
         if not isinstance(s, torch.Tensor):
@@ -47,7 +48,8 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
+        if not self._unbounded:
+            mu = self._max * torch.tanh(self.mu(logits))
         sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None
 
diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py
index 1c90708..57ca8ba 100644
--- a/examples/halfcheetahBullet_v0_sac.py
+++ b/examples/halfcheetahBullet_v0_sac.py
@@ -68,7 +68,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py
index ed8ca46..6455975 100644
--- a/examples/sac_mcc.py
+++ b/examples/sac_mcc.py
@@ -64,7 +64,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index e760a5d..1e12721 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -18,6 +18,7 @@ def offpolicy_trainer(
         collect_per_step: int,
         episode_per_test: Union[int, List[int]],
         batch_size: int,
+        update_per_step: int = 1,
         train_fn: Optional[Callable[[int], None]] = None,
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
@@ -42,10 +43,13 @@ def offpolicy_trainer(
         in one epoch.
     :param int collect_per_step: the number of frames the collector would
         collect before the network update. In other words, collect some frames
-        and do one policy network update.
+        and do some policy network update.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
+    :param int update_per_step: the number of times the policy network would
+        be updated after frames be collected. In other words, collect some
+        frames and do some policy network update.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in
         this epoch.
@@ -98,7 +102,7 @@
                         policy.train()
                         if train_fn:
                             train_fn(epoch)
-                for i in range(min(
+                for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += 1
                     losses = policy.learn(train_collector.sample(batch_size))
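
Note (not part of the patch): a minimal sketch of what the new unbounded flag
does to the actor's output head. GaussianHead below is a hypothetical,
self-contained stand-in for ActorProb's last layers, not tianshou API; it only
illustrates that with unbounded=True the raw Gaussian mean is returned instead
of a tanh-scaled one, which suits SAC because the policy applies its own tanh
squashing to the sampled action.

import torch
import torch.nn as nn


class GaussianHead(nn.Module):
    """Hypothetical stand-in for ActorProb's output layers (illustration only)."""

    def __init__(self, hidden_dim, action_dim, max_action, unbounded=False):
        super().__init__()
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.sigma = nn.Linear(hidden_dim, action_dim)
        self._max = max_action
        self._unbounded = unbounded

    def forward(self, logits):
        mu = self.mu(logits)
        if not self._unbounded:
            # bounded mode: scale the mean into [-max_action, max_action]
            mu = self._max * torch.tanh(mu)
        sigma = torch.exp(self.sigma(logits))
        return mu, sigma


# Example: an unbounded head, as the SAC examples now request via unbounded=True.
head = GaussianHead(hidden_dim=128, action_dim=1, max_action=1.0, unbounded=True)
mu, sigma = head(torch.zeros(4, 128))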
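
Note (not part of the patch): a usage sketch of the new update_per_step
argument. It assumes policy, train_collector, and test_collector have already
been built, e.g. as in examples/sac_mcc.py above; every number below is a
placeholder, not a recommended setting.

from tianshou.trainer import offpolicy_trainer

# `policy`, `train_collector` and `test_collector` are assumed to exist already
# (built as in the SAC examples above); the hyperparameters are illustrative.
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=100,
    step_per_epoch=1000,
    collect_per_step=10,   # collect at least 10 frames before each update phase
    episode_per_test=10,
    batch_size=128,
    update_per_step=4,     # new: do 4 gradient updates per qualifying step
)

With update_per_step=k, the loop changed above runs
k * min(result['n/st'] // collect_per_step, t.total - t.n) calls to
policy.learn() after each collection instead of the plain min(...) it ran
before; the default update_per_step=1 keeps the old behaviour.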