* Add auto alpha tuning and exploration noise for sac. Add class BaseNoise and GaussianNoise for the concept of exploration noise. Add new test for sac tested in MountainCarContinuous-v0, which should benefits from the two above new feature. * add exploration noise to collector, fix example to adapt modification * fix #98 * enable off-policy to update multiple times in one step. (#99)
This commit is contained in:
parent
a951a32487
commit
60cfc373f8
@ -60,7 +60,7 @@ def test_sac(args=get_args()):
|
||||
# model
|
||||
actor = ActorProb(
|
||||
args.layer_num, args.state_shape, args.action_shape,
|
||||
args.max_action, args.device
|
||||
args.max_action, args.device, unbounded=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
critic1 = Critic(
|
||||
|
@ -28,7 +28,7 @@ class Actor(nn.Module):
|
||||
|
||||
class ActorProb(nn.Module):
|
||||
def __init__(self, layer_num, state_shape, action_shape,
|
||||
max_action, device='cpu'):
|
||||
max_action, device='cpu', unbounded=False):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.model = [
|
||||
@ -40,6 +40,7 @@ class ActorProb(nn.Module):
|
||||
self.mu = nn.Linear(128, np.prod(action_shape))
|
||||
self.sigma = nn.Linear(128, np.prod(action_shape))
|
||||
self._max = max_action
|
||||
self._unbounded = unbounded
|
||||
|
||||
def forward(self, s, **kwargs):
|
||||
if not isinstance(s, torch.Tensor):
|
||||
@ -47,7 +48,8 @@ class ActorProb(nn.Module):
|
||||
batch = s.shape[0]
|
||||
s = s.view(batch, -1)
|
||||
logits = self.model(s)
|
||||
mu = self._max * torch.tanh(self.mu(logits))
|
||||
if not self._unbounded:
|
||||
mu = self._max * torch.tanh(self.mu(logits))
|
||||
sigma = torch.exp(self.sigma(logits))
|
||||
return (mu, sigma), None
|
||||
|
||||
|
@ -68,7 +68,7 @@ def test_sac(args=get_args()):
|
||||
# model
|
||||
actor = ActorProb(
|
||||
args.layer_num, args.state_shape, args.action_shape,
|
||||
args.max_action, args.device
|
||||
args.max_action, args.device, unbounded=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
critic1 = Critic(
|
||||
|
@ -64,7 +64,7 @@ def test_sac(args=get_args()):
|
||||
# model
|
||||
actor = ActorProb(
|
||||
args.layer_num, args.state_shape, args.action_shape,
|
||||
args.max_action, args.device
|
||||
args.max_action, args.device, unbounded=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
critic1 = Critic(
|
||||
|
@ -18,6 +18,7 @@ def offpolicy_trainer(
|
||||
collect_per_step: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
update_per_step: int = 1,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
@ -42,10 +43,13 @@ def offpolicy_trainer(
|
||||
in one epoch.
|
||||
:param int collect_per_step: the number of frames the collector would
|
||||
collect before the network update. In other words, collect some frames
|
||||
and do one policy network update.
|
||||
and do some policy network update.
|
||||
:param episode_per_test: the number of episodes for one policy evaluation.
|
||||
:param int batch_size: the batch size of sample data, which is going to
|
||||
feed in the policy network.
|
||||
:param int update_per_step: the number of times the policy network would
|
||||
be updated after frames be collected. In other words, collect some
|
||||
frames and do some policy network update.
|
||||
:param function train_fn: a function receives the current number of epoch
|
||||
index and performs some operations at the beginning of training in this
|
||||
epoch.
|
||||
@ -98,7 +102,7 @@ def offpolicy_trainer(
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
for i in range(min(
|
||||
for i in range(update_per_step * min(
|
||||
result['n/st'] // collect_per_step, t.total - t.n)):
|
||||
global_step += 1
|
||||
losses = policy.learn(train_collector.sample(batch_size))
|
||||
|
Loading…
x
Reference in New Issue
Block a user