diff --git a/configs.yaml b/configs.yaml
index 35e3fea..700712e 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -59,9 +59,9 @@ defaults:
     {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
   value_head: 'symlog_disc'
   reward_head: 'symlog_disc'
-  dyn_scale: '0.5'
-  rep_scale: '0.1'
-  kl_free: '1.0'
+  dyn_scale: 0.5
+  rep_scale: 0.1
+  kl_free: 1.0
   cont_scale: 1.0
   reward_scale: 1.0
   weight_decay: 0.0
@@ -93,10 +93,10 @@ defaults:
   discount_lambda: 0.95
   imag_horizon: 15
   imag_gradient: 'dynamics'
-  imag_gradient_mix: '0.0'
+  imag_gradient_mix: 0.0
   imag_sample: True
   actor_dist: 'normal'
-  actor_entropy: '3e-4'
+  actor_entropy: 3e-4
   actor_state_entropy: 0.0
   actor_init_std: 1.0
   actor_min_std: 0.1
diff --git a/dreamer.py b/dreamer.py
index b3c5705..2b61b90 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -40,16 +40,6 @@ class Dreamer(nn.Module):
         # this is update step
         self._step = logger.step // config.action_repeat
         self._update_count = 0
-        # Schedules.
-        config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
-            x, self._step
-        )
-        config.actor_state_entropy = (
-            lambda x=config.actor_state_entropy: tools.schedule(x, self._step)
-        )
-        config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule(
-            x, self._step
-        )
         self._dataset = dataset
         self._wm = models.WorldModel(obs_space, act_space, self._step, config)
         self._task_behavior = models.ImagBehavior(
diff --git a/models.py b/models.py
index b76d842..388fb36 100644
--- a/models.py
+++ b/models.py
@@ -128,9 +128,9 @@ class WorldModel(nn.Module):
                 post, prior = self.dynamics.observe(
                     embed, data["action"], data["is_first"]
                 )
-                kl_free = tools.schedule(self._config.kl_free, self._step)
-                dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
-                rep_scale = tools.schedule(self._config.rep_scale, self._step)
+                kl_free = self._config.kl_free
+                dyn_scale = self._config.dyn_scale
+                rep_scale = self._config.rep_scale
                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                     post, prior, kl_free, dyn_scale, rep_scale
                 )
@@ -393,10 +393,10 @@ class ImagBehavior(nn.Module):
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
-        if self._config.future_entropy and self._config.actor_entropy() > 0:
-            reward += self._config.actor_entropy() * actor_ent
-        if self._config.future_entropy and self._config.actor_state_entropy() > 0:
-            reward += self._config.actor_state_entropy() * state_ent
+        if self._config.future_entropy and self._config.actor_entropy > 0:
+            reward += self._config.actor_entropy * actor_ent
+        if self._config.future_entropy and self._config.actor_state_entropy > 0:
+            reward += self._config.actor_state_entropy * state_ent
         value = self.value(imag_feat).mode()
         target = tools.lambda_return(
             reward[1:],
@@ -450,16 +450,16 @@ class ImagBehavior(nn.Module):
                 policy.log_prob(imag_action)[:-1][:, :, None]
                 * (target - self.value(imag_feat[:-1]).mode()).detach()
             )
-            mix = self._config.imag_gradient_mix()
+            mix = self._config.imag_gradient_mix
             actor_target = mix * target + (1 - mix) * actor_target
             metrics["imag_gradient_mix"] = mix
         else:
             raise NotImplementedError(self._config.imag_gradient)
-        if not self._config.future_entropy and (self._config.actor_entropy() > 0):
-            actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
+        if not self._config.future_entropy and self._config.actor_entropy > 0:
+            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
             actor_target += actor_entropy
-        if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
-            state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
+        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
+            state_entropy = self._config.actor_state_entropy * state_ent[:-1]
             actor_target += state_entropy
             metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
         actor_loss = -torch.mean(weights[:-1] * actor_target)
diff --git a/tools.py b/tools.py
index 69b9edd..f9edc1d 100644
--- a/tools.py
+++ b/tools.py
@@ -899,33 +899,6 @@ class Until:
         return step < self._until
 
 
-def schedule(string, step):
-    try:
-        return float(string)
-    except ValueError:
-        match = re.match(r"linear\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, duration = [float(group) for group in match.groups()]
-            mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
-            return (1 - mix) * initial + mix * final
-        match = re.match(r"warmup\((.+),(.+)\)", string)
-        if match:
-            warmup, value = [float(group) for group in match.groups()]
-            scale = torch.clip(step / warmup, 0, 1)
-            return scale * value
-        match = re.match(r"exp\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, halflife = [float(group) for group in match.groups()]
-            return (initial - final) * 0.5 ** (step / halflife) + final
-        match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, duration = [float(group) for group in match.groups()]
-            mix = torch.clip(step / duration, 0, 1)
-            horizon = (1 - mix) * initial + mix * final
-            return 1 - 1 / horizon
-        raise NotImplementedError(string)
-
-
 def weight_init(m):
     if isinstance(m, nn.Linear):
         in_num = m.in_features
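
Note on the removed helper: tools.schedule accepted either a plain number or a string spec ("linear(initial,final,duration)", "warmup(warmup,value)", "exp(initial,final,halflife)", "horizon(initial,final,duration)") and was re-evaluated against the current step on every access, which is why dreamer.py wrapped the config fields in lambdas and call sites read config.actor_entropy(). After this patch the fields are constant floats read directly. For anyone who still wants annealing, below is a minimal, torch-free sketch of the linear variant (the function name and the spec string in the usage comment are illustrative, not part of the repo):

import re

def linear_schedule(spec, step):
    # Plain numbers pass through unchanged; this constant form is the only
    # one the patched configs.yaml uses.
    try:
        return float(spec)
    except ValueError:
        match = re.match(r"linear\((.+),(.+),(.+)\)", spec)
        if not match:
            raise NotImplementedError(spec)
        initial, final, duration = (float(g) for g in match.groups())
        mix = min(max(step / duration, 0.0), 1.0)  # clip to [0, 1] without torch
        return (1 - mix) * initial + mix * final

# e.g. linear_schedule("linear(3e-4,3e-5,2.5e6)", step=1_250_000) -> ~1.65e-4
# (halfway through the duration, halfway between initial and final)

To re-introduce a schedule on top of this patch, evaluate it at the call site each training step (e.g. in ImagBehavior._compute_actor_loss) rather than mutating config fields into lambdas, which is what the deleted dreamer.py block did.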