fix #98, support #99 (#102)

* Add auto alpha tuning and exploration noise for SAC.
Add the classes BaseNoise and GaussianNoise to model exploration noise.
Add a new SAC test on MountainCarContinuous-v0,
which should benefit from the two new features above (a sketch of the noise helper follows this list).

* Add exploration noise to the Collector and fix the examples to adapt to this change.

* Fix #98.

* Enable off-policy training to update the network multiple times within one collect step. (#99)
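The BaseNoise and GaussianNoise classes themselves do not appear in the hunks below; the following is a minimal sketch of what such an additive Gaussian exploration-noise helper could look like. The class names mirror the ones mentioned above, but the signature, defaults, and reset() hook are assumptions rather than the actual tianshou implementation.

from abc import ABC, abstractmethod

import numpy as np


class BaseNoise(ABC):
    """Abstract interface for action-space exploration noise."""

    @abstractmethod
    def __call__(self, size):
        """Return a noise sample with the given shape."""
        raise NotImplementedError

    def reset(self):
        """Reset internal state; a no-op for stateless noise."""
        pass


class GaussianNoise(BaseNoise):
    """I.i.d. Gaussian noise to be added to actions at collect time."""

    def __init__(self, mu=0.0, sigma=1.0):
        self._mu = mu
        self._sigma = sigma

    def __call__(self, size):
        return np.random.normal(self._mu, self._sigma, size)

With a helper like this, the collector can perturb each action with noise(action.shape) before stepping the environment, which is the kind of extra exploration the commit message expects to help on MountainCarContinuous-v0.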
Authored by danagi on 2020-06-27 21:40:09 +08:00; committed by GitHub
parent a951a32487
commit 60cfc373f8
5 changed files with 13 additions and 7 deletions

View File

@@ -60,7 +60,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

View File

@@ -28,7 +28,7 @@ class Actor(nn.Module):
 
 class ActorProb(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
+                 max_action, device='cpu', unbounded=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -40,6 +40,7 @@ class ActorProb(nn.Module):
         self.mu = nn.Linear(128, np.prod(action_shape))
         self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
+        self._unbounded = unbounded
 
     def forward(self, s, **kwargs):
         if not isinstance(s, torch.Tensor):
@@ -47,7 +48,8 @@ class ActorProb(nn.Module):
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
+        if not self._unbounded:
+            mu = self._max * torch.tanh(self.mu(logits))
         sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None
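For context: SAC squashes the sampled action with a tanh transform inside the policy itself, so also scaling the network mean by max_action * tanh(...) compresses the usable action range; unbounded=True lets the head return a raw mean instead. Below is a minimal standalone sketch of the two output modes; the function name, the explicit unbounded branch, and the example sizes are illustrative assumptions, not code from the file above.

import torch


def gaussian_mean(mu_layer, logits, max_action, unbounded):
    """Compute the Gaussian mean for a bounded or unbounded actor head."""
    if not unbounded:
        # bounded head: scale the mean into [-max_action, max_action]
        return max_action * torch.tanh(mu_layer(logits))
    # unbounded head: return the raw linear output; SAC applies its own
    # tanh squashing to the sampled action later in the policy
    return mu_layer(logits)


# tiny usage example with made-up sizes
mu_layer = torch.nn.Linear(128, 2)
logits = torch.zeros(4, 128)
mu = gaussian_mean(mu_layer, logits, max_action=1.0, unbounded=True)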

View File

@@ -68,7 +68,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

View File

@@ -64,7 +64,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
    ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

View File

@@ -18,6 +18,7 @@ def offpolicy_trainer(
         collect_per_step: int,
         episode_per_test: Union[int, List[int]],
         batch_size: int,
+        update_per_step: int = 1,
         train_fn: Optional[Callable[[int], None]] = None,
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
@@ -42,10 +43,13 @@ def offpolicy_trainer(
         in one epoch.
     :param int collect_per_step: the number of frames the collector would
         collect before the network update. In other words, collect some frames
-        and do one policy network update.
+        and do some policy network updates.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
+    :param int update_per_step: the number of times the policy network would
+        be updated after frames are collected. In other words, collect some
+        frames and do several policy network updates.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in this
         epoch.
@@ -98,7 +102,7 @@ def offpolicy_trainer(
                         policy.train()
                         if train_fn:
                             train_fn(epoch)
-                for i in range(min(
+                for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += 1
                     losses = policy.learn(train_collector.sample(batch_size))
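To make the new loop bound concrete, here is the update-count arithmetic with made-up numbers; with the default update_per_step=1 the trainer behaves exactly as before.

# Made-up numbers illustrating the changed loop bound above.
collect_per_step = 10   # frames to collect before updating
update_per_step = 4     # gradient updates per collect_per_step frames
n_collected = 40        # result['n/st'] reported by the collector
remaining = 1000        # t.total - t.n, steps left in this epoch

n_updates = update_per_step * min(n_collected // collect_per_step, remaining)
print(n_updates)  # 16 calls to policy.learn() for this collect step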