fix #98, support #99 (#102)

* Add auto alpha tuning and exploration noise for SAC.
Add classes BaseNoise and GaussianNoise for the concept of exploration noise
(a minimal sketch of the idea follows the change summary below).
Add a new test for SAC on MountainCarContinuous-v0,
which should benefit from the two new features above.

* add exploration noise to the collector; fix the examples to adapt to this change

* fix #98

* enable off-policy algorithms to update the network multiple times in one step (#99)
danagi 2020-06-27 21:40:09 +08:00 committed by GitHub
parent a951a32487
commit 60cfc373f8
5 changed files with 13 additions and 7 deletions
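
The BaseNoise and GaussianNoise classes mentioned in the commit message are not among the hunks shown on this page. As a rough, illustrative sketch only (the class and argument names follow the commit message; the bodies are assumptions, not the committed implementation), additive Gaussian exploration noise can be as simple as:

import numpy as np


class BaseNoise:
    """Base class for exploration noise (illustrative sketch)."""

    def reset(self):
        # Stateless noise has nothing to reset; stateful variants
        # (e.g. Ornstein-Uhlenbeck noise) would override this.
        pass

    def __call__(self, size):
        raise NotImplementedError


class GaussianNoise(BaseNoise):
    """I.i.d. Gaussian exploration noise (illustrative sketch)."""

    def __init__(self, mu=0.0, sigma=1.0):
        self._mu = mu
        self._sigma = sigma

    def __call__(self, size):
        # Sample noise of the requested shape; during collection it is
        # added to the actions produced by the policy.
        return np.random.normal(self._mu, self._sigma, size)

Auto alpha tuning, the other feature named above, usually turns the entropy coefficient into a learnable log-parameter and minimizes -log_alpha * (log_prob + target_entropy).detach(), driving the policy entropy toward a chosen target.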


@@ -60,7 +60,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(


@@ -28,7 +28,7 @@ class Actor(nn.Module):

 class ActorProb(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
+                 max_action, device='cpu', unbounded=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -40,6 +40,7 @@ class ActorProb(nn.Module):
         self.mu = nn.Linear(128, np.prod(action_shape))
         self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
+        self._unbounded = unbounded

     def forward(self, s, **kwargs):
         if not isinstance(s, torch.Tensor):
@@ -47,6 +48,7 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
+        if not self._unbounded:
+            mu = self._max * torch.tanh(self.mu(logits))
         sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None

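The new unbounded flag keeps the default behaviour (the mean is squashed to max_action * tanh(mu)) but lets callers ask for a raw, unsquashed mean, which is what the SAC scripts above request with unbounded=True. This matters because a squashed-Gaussian SAC policy applies tanh to the sampled action itself, so a tanh-bounded network mean would effectively be squashed twice. A minimal sketch of that standard squashed-Gaussian construction (the textbook SAC recipe, not the library's exact code):

import torch
from torch.distributions import Independent, Normal

# Unbounded mean and scale, as produced by ActorProb(..., unbounded=True).
mu = torch.zeros(4, 1)
sigma = torch.ones(4, 1)
max_action = 1.0

dist = Independent(Normal(mu, sigma), 1)
z = dist.rsample()                     # reparameterized Gaussian sample
a = max_action * torch.tanh(z)         # squash once, inside the SAC policy
# change-of-variables correction for the tanh squashing
log_prob = dist.log_prob(z) - torch.log(
    max_action * (1 - torch.tanh(z).pow(2)) + 1e-6).sum(dim=-1)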

@@ -68,7 +68,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
    ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(


@@ -64,7 +64,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(


@@ -18,6 +18,7 @@ def offpolicy_trainer(
         collect_per_step: int,
         episode_per_test: Union[int, List[int]],
         batch_size: int,
+        update_per_step: int = 1,
         train_fn: Optional[Callable[[int], None]] = None,
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
@@ -42,10 +43,13 @@ def offpolicy_trainer(
         in one epoch.
     :param int collect_per_step: the number of frames the collector would
         collect before the network update. In other words, collect some frames
-        and do one policy network update.
+        and do some policy network update.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
+    :param int update_per_step: the number of times the policy network would
+        be updated after frames are collected. In other words, collect some
+        frames and do some policy network update.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in this
         epoch.
@@ -98,7 +102,7 @@ def offpolicy_trainer(
             policy.train()
             if train_fn:
                 train_fn(epoch)
-            for i in range(min(
+            for i in range(update_per_step * min(
                     result['n/st'] // collect_per_step, t.total - t.n)):
                 global_step += 1
                 losses = policy.learn(train_collector.sample(batch_size))
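
The changed loop multiplies the old per-round update count by update_per_step, so after each collection round the trainer performs roughly update_per_step * (collected_frames // collect_per_step) gradient updates, still capped by the epoch's remaining step budget. A small numeric illustration (the numbers are made up):

# Illustration of the update count implied by the changed loop above.
collect_per_step = 10     # frames collected before updating
collected_frames = 100    # what result['n/st'] might report for one round
remaining_steps = 1000    # t.total - t.n, the epoch's remaining budget

for update_per_step in (1, 2, 4):
    n_updates = update_per_step * min(
        collected_frames // collect_per_step, remaining_steps)
    print(update_per_step, n_updates)   # prints: 1 10, 2 20, 4 40

With the default update_per_step=1, the trainer behaves exactly as before.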