diff --git a/dreamer.py b/dreamer.py
index 591994b..ba4bc17 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -62,7 +62,7 @@ class Dreamer(nn.Module):
             greedy=lambda: self._task_behavior,
             random=lambda: expl.Random(config),
             plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
-        )[config.expl_behavior]()
+        )[config.expl_behavior]().to(self._config.device)
 
     def __call__(self, obs, reset, state=None, reward=None, training=True):
         step = self._step
diff --git a/exploration.py b/exploration.py
index 77cbc3b..40985cd 100644
--- a/exploration.py
+++ b/exploration.py
@@ -26,83 +26,94 @@ class Random(nn.Module):
 # class Plan2Explore(tools.Module):
 class Plan2Explore(nn.Module):
     def __init__(self, config, world_model, reward=None):
+        super(Plan2Explore, self).__init__()
         self._config = config
+        self._use_amp = True if config.precision == 16 else False
         self._reward = reward
         self._behavior = models.ImagBehavior(config, world_model)
         self.actor = self._behavior.actor
-        stoch_size = config.dyn_stoch
         if config.dyn_discrete:
-            stoch_size *= config.dyn_discrete
+            feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
+            stoch = config.dyn_stoch * config.dyn_discrete
+        else:
+            feat_size = config.dyn_stoch + config.dyn_deter
+            stoch = config.dyn_stoch
         size = {
             "embed": 32 * config.cnn_depth,
-            "stoch": stoch_size,
+            "stoch": stoch,
             "deter": config.dyn_deter,
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=config.dyn_stoch,  # pytorch version
+            inp_dim=feat_size,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
             act=config.act,
         )
-        self._networks = [networks.DenseHead(**kw) for _ in range(config.disag_models)]
-        self._opt = tools.optimizer(
-            config.opt,
+        self._networks = nn.ModuleList(
+            [networks.DenseHead(**kw) for _ in range(config.disag_models)]
+        )
+        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
+        self._model_opt = tools.Optimizer(
+            "explorer",
             self.parameters(),
             config.model_lr,
             config.opt_eps,
-            config.weight_decay,
+            config.grad_clip,
+            **kw
         )
-        # self._opt = tools.Optimizer(
-        #     'ensemble', config.model_lr, config.opt_eps, config.grad_clip,
-        #     config.weight_decay, opt=config.opt)
 
     def train(self, start, context, data):
-        metrics = {}
-        stoch = start["stoch"]
-        if self._config.dyn_discrete:
-            stoch = tf.reshape(
-                stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])
-            )
-        target = {
-            "embed": context["embed"],
-            "stoch": stoch,
-            "deter": start["deter"],
-            "feat": context["feat"],
-        }[self._config.disag_target]
-        inputs = context["feat"]
-        if self._config.disag_action_cond:
-            inputs = tf.concat([inputs, data["action"]], -1)
-        metrics.update(self._train_ensemble(inputs, target))
-        metrics.update(self._behavior.train(start, self._intrinsic_reward)[-1])
+        with tools.RequiresGrad(self):
+            metrics = {}
+            stoch = start["stoch"]
+            if self._config.dyn_discrete:
+                stoch = torch.reshape(
+                    stoch, (stoch.shape[:-2] + ((stoch.shape[-2] * stoch.shape[-1]),))
+                )
+            target = {
+                "embed": context["embed"],
+                "stoch": stoch,
+                "deter": start["deter"],
+                "feat": context["feat"],
+            }[self._config.disag_target]
+            inputs = context["feat"]
+            if self._config.disag_action_cond:
+                inputs = torch.concat([inputs, data["action"]], -1)
+            metrics.update(self._train_ensemble(inputs, target))
+            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
 
     def _intrinsic_reward(self, feat, state, action):
         inputs = feat
         if self._config.disag_action_cond:
-            inputs = tf.concat([inputs, action], -1)
-        preds = [head(inputs, tf.float32).mean() for head in self._networks]
-        disag = tf.reduce_mean(tf.math.reduce_std(preds, 0), -1)
+            inputs = torch.concat([inputs, action], -1)
+        preds = torch.cat(
+            [head(inputs, torch.float32).mode()[None] for head in self._networks], 0
+        )
+        disag = torch.mean(torch.std(preds, 0), -1)[..., None]
         if self._config.disag_log:
-            disag = tf.math.log(disag)
+            disag = torch.log(disag)
         reward = self._config.expl_intr_scale * disag
         if self._config.expl_extr_scale:
-            reward += tf.cast(
+            reward += torch.cast(
                 self._config.expl_extr_scale * self._reward(feat, state, action),
-                tf.float32,
+                torch.float32,
             )
         return reward
 
     def _train_ensemble(self, inputs, targets):
-        if self._config.disag_offset:
-            targets = targets[:, self._config.disag_offset :]
-            inputs = inputs[:, : -self._config.disag_offset]
-        targets = tf.stop_gradient(targets)
-        inputs = tf.stop_gradient(inputs)
-        with tf.GradientTape() as tape:
+        with torch.cuda.amp.autocast(self._use_amp):
+            if self._config.disag_offset:
+                targets = targets[:, self._config.disag_offset :]
+                inputs = inputs[:, : -self._config.disag_offset]
+            targets = targets.detach()
+            inputs = inputs.detach()
             preds = [head(inputs) for head in self._networks]
-            likes = [tf.reduce_mean(pred.log_prob(targets)) for pred in preds]
-            loss = -tf.cast(tf.reduce_sum(likes), tf.float32)
-        metrics = self._opt(tape, loss, self._networks)
+            likes = torch.cat(
+                [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
+            )
+            loss = -torch.mean(likes)
+        metrics = self._model_opt(loss, self.parameters())
         return metrics
diff --git a/tools.py b/tools.py
index 419af2c..a73b238 100644
--- a/tools.py
+++ b/tools.py
@@ -548,7 +548,7 @@ class Optimizer():
                 lr=lr,
                 eps=eps),
             'nadam': lambda: NotImplemented(
-                f'{config.opt} is not implemented'),
+                f'{opt} is not implemented'),
             'adamax': lambda: torch.optim.Adamax(parameters,
                 lr=lr,
                 eps=eps),
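
Reviewer note (not part of the patch): the `_intrinsic_reward` hunk above ports Plan2Explore's ensemble-disagreement bonus to PyTorch. Below is a minimal standalone sketch of that computation with made-up tensor shapes and plain `nn.Linear` heads standing in for `networks.DenseHead`, so the scale factor and all shape choices here are illustrative assumptions, not the repo's values.

# Illustrative only: disagreement-based intrinsic reward over an ensemble of heads.
import torch
import torch.nn as nn

batch, horizon, feat_dim, target_dim, n_heads = 16, 15, 230, 230, 10
feat = torch.randn(batch, horizon, feat_dim)

# Stand-ins for self._networks (each head predicts the disagreement target).
heads = nn.ModuleList([nn.Linear(feat_dim, target_dim) for _ in range(n_heads)])

# Stack per-head predictions: (n_heads, batch, horizon, target_dim).
preds = torch.stack([head(feat) for head in heads], 0)

# Disagreement = mean over target dims of the std across ensemble members,
# kept with a trailing singleton dim so it lines up with per-step rewards.
disag = torch.mean(torch.std(preds, 0), -1)[..., None]
intr_reward = 0.1 * disag  # stand-in for config.expl_intr_scale * disag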