From 02c3d45fcfc488151dd6ac7e3e905b55031c8d8f Mon Sep 17 00:00:00 2001
From: NM512
Date: Sun, 21 May 2023 08:17:47 +0900
Subject: [PATCH] modification of expl.

---
 dreamer.py     |  2 +-
 exploration.py | 29 +++++++++++++++++++++--------
 models.py      | 13 ++++++-------
 3 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/dreamer.py b/dreamer.py
index 5a23d88..12ef8a2 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -61,7 +61,7 @@ class Dreamer(nn.Module):
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean
         self._expl_behavior = dict(
             greedy=lambda: self._task_behavior,
-            random=lambda: expl.Random(config),
+            random=lambda: expl.Random(config, act_space),
             plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
         )[config.expl_behavior]().to(self._config.device)
 
diff --git a/exploration.py b/exploration.py
index f57877c..f195bb8 100644
--- a/exploration.py
+++ b/exploration.py
@@ -8,22 +8,35 @@ import tools
 
 
 class Random(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, act_space):
+        super(Random, self).__init__()
         self._config = config
+        self._act_space = act_space
 
     def actor(self, feat):
-        shape = feat.shape[:-1] + [self._config.num_actions]
         if self._config.actor_dist == "onehot":
-            return tools.OneHotDist(torch.zeros(shape))
+            return tools.OneHotDist(
+                torch.zeros(self._config.num_actions)
+                .repeat(self._config.envs, 1)
+                .to(self._config.device)
+            )
         else:
-            ones = torch.ones(shape)
-            return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
+            return torchd.independent.Independent(
+                torchd.uniform.Uniform(
+                    torch.Tensor(self._act_space.low)
+                    .repeat(self._config.envs, 1)
+                    .to(self._config.device),
+                    torch.Tensor(self._act_space.high)
+                    .repeat(self._config.envs, 1)
+                    .to(self._config.device),
+                ),
+                1,
+            )
 
-    def train(self, start, context):
+    def train(self, start, context, data):
         return None, {}
 
 
-# class Plan2Explore(tools.Module):
 class Plan2Explore(nn.Module):
     def __init__(self, config, world_model, reward=None):
         super(Plan2Explore, self).__init__()
@@ -39,7 +52,7 @@ class Plan2Explore(nn.Module):
         feat_size = config.dyn_stoch + config.dyn_deter
         stoch = config.dyn_stoch
         size = {
-            "embed": 32 * config.cnn_depth,
+            "embed": world_model.embed_size,
             "stoch": stoch,
             "deter": config.dyn_deter,
             "feat": config.dyn_stoch + config.dyn_deter,
diff --git a/models.py b/models.py
index 163765c..8402d10 100644
--- a/models.py
+++ b/models.py
@@ -36,7 +36,7 @@ class WorldModel(nn.Module):
         self._config = config
         shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
         self.encoder = networks.MultiEncoder(shapes, **config.encoder)
-        embed_size = self.encoder.outdim
+        self.embed_size = self.encoder.outdim
         self.dynamics = networks.RSSM(
             config.dyn_stoch,
             config.dyn_deter,
@@ -56,7 +56,7 @@
             config.unimix_ratio,
             config.initial,
             config.num_actions,
-            embed_size,
+            self.embed_size,
             config.device,
         )
         self.heads = nn.ModuleDict()
@@ -228,7 +228,7 @@ class ImagBehavior(nn.Module):
         else:
             feat_size = config.dyn_stoch + config.dyn_deter
         self.actor = networks.ActionHead(
-            feat_size,  # pytorch version
+            feat_size,
             config.num_actions,
             config.actor_layers,
             config.units,
@@ -244,7 +244,7 @@
         )
         if config.value_head == "symlog_disc":
             self.value = networks.MLP(
-                feat_size,  # pytorch version
+                feat_size,
                 (255,),
                 config.value_layers,
                 config.units,
@@ -256,7 +256,7 @@
             )
         else:
             self.value = networks.MLP(
-                feat_size,  # pytorch version
+                feat_size,
                 [],
                 config.value_layers,
                 config.units,
@@ -356,7 +356,7 @@ class ImagBehavior(nn.Module):
             )
         else:
             metrics.update(tools.tensorstats(imag_action, "imag_action"))
-        metrics["actor_ent"] = to_np(torch.mean(actor_ent))
+        metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
         with tools.RequiresGrad(self):
             metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
             metrics.update(self._value_opt(value_loss, self.value.parameters()))
@@ -462,7 +462,6 @@ class ImagBehavior(nn.Module):
         if not self._config.future_entropy and (self._config.actor_entropy() > 0):
             actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
             actor_target += actor_entropy
-            metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
         if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
             state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
             actor_target += state_entropy
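
Below is a minimal usage sketch (not part of the patch) of how the modified expl.Random behavior could be exercised; the SimpleNamespace config and the Box-like act_space stand-in are illustrative assumptions rather than values from the repository's config files, and the snippet assumes it runs from the repository root so that exploration.py and its dependencies import correctly.

# Illustrative sketch only: exercises exploration.Random as modified above.
# The config values and the act_space stand-in below are assumptions.
from types import SimpleNamespace

import numpy as np
import torch

import exploration as expl

config = SimpleNamespace(
    actor_dist="normal",  # any value other than "onehot" selects the uniform branch
    num_actions=6,
    envs=4,
    device="cpu",
)
# Stand-in for a gym Box action space: Random only reads .low and .high.
act_space = SimpleNamespace(
    low=np.full(config.num_actions, -1.0, dtype=np.float32),
    high=np.full(config.num_actions, 1.0, dtype=np.float32),
)

random_behavior = expl.Random(config, act_space)
feat = torch.zeros(config.envs, 32)  # Random.actor ignores the features
dist = random_behavior.actor(feat)
action = dist.sample()  # shape (envs, num_actions), uniform over [low, high]
print(action.shape)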