From a9e85e8b7ce8e4a3e8e6805d5cd918217599daee Mon Sep 17 00:00:00 2001
From: NM512
Date: Fri, 5 Jan 2024 10:46:54 +0900
Subject: [PATCH] modified weight initialization

---
 models.py   |  6 ++--
 networks.py | 81 +++++++++++++++++++++++++++++------------------------
 tools.py    | 14 ++++++++-
 3 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/models.py b/models.py
index a97fbc8..69f6b53 100644
--- a/models.py
+++ b/models.py
@@ -179,7 +179,7 @@ class WorldModel(nn.Module):
     # this function is called during both rollout and training
     def preprocess(self, obs):
         obs = obs.copy()
-        obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
+        obs["image"] = torch.Tensor(obs["image"]) / 255.0
         if "discount" in obs:
             obs["discount"] *= self._config.discount
             # (batch_size, batch_length) -> (batch_size, batch_length, 1)
@@ -209,8 +209,8 @@ class WorldModel(nn.Module):
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
         model = torch.cat([recon[:, :5], openl], 1)
-        truth = data["image"][:6] + 0.5
-        model = model + 0.5
+        truth = data["image"][:6]
+        model = model
         error = (model - truth + 1.0) / 2.0
 
         return torch.cat([truth, model, error], 2)
diff --git a/networks.py b/networks.py
index f43630c..3616e2e 100644
--- a/networks.py
+++ b/networks.py
@@ -68,9 +68,8 @@ class RSSM(nn.Module):
             inp_layers.append(act())
             if i == 0:
                 inp_dim = self._hidden
-        self._inp_layers = nn.Sequential(*inp_layers)
-        self._inp_layers.apply(tools.weight_init)
-
+        self._img_in_layers = nn.Sequential(*inp_layers)
+        self._img_in_layers.apply(tools.weight_init)
         if cell == "gru":
             self._cell = GRUCell(self._hidden, self._deter)
             self._cell.apply(tools.weight_init)
@@ -106,15 +105,17 @@ class RSSM(nn.Module):
             self._obs_out_layers.apply(tools.weight_init)
 
         if self._discrete:
-            self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
-            self._ims_stat_layer.apply(tools.weight_init)
+            self._imgs_stat_layer = nn.Linear(
+                self._hidden, self._stoch * self._discrete
+            )
+            self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
             self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
-            self._obs_stat_layer.apply(tools.weight_init)
+            self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
         else:
-            self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
-            self._ims_stat_layer.apply(tools.weight_init)
+            self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
+            self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
             self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
-            self._obs_stat_layer.apply(tools.weight_init)
+            self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
 
         if self._initial == "learned":
             self.W = torch.nn.Parameter(
@@ -260,7 +261,7 @@ class RSSM(nn.Module):
         else:
             x = torch.cat([prev_stoch, prev_action], -1)
         # (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
-        x = self._inp_layers(x)
+        x = self._img_in_layers(x)
         for _ in range(self._rec_depth):  # rec depth is not correctly implemented
             deter = prev_state["deter"]
             # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
@@ -286,7 +287,7 @@ class RSSM(nn.Module):
     def _suff_stats_layer(self, name, x):
         if self._discrete:
             if name == "ims":
-                x = self._ims_stat_layer(x)
+                x = self._imgs_stat_layer(x)
             elif name == "obs":
                 x = self._obs_stat_layer(x)
             else:
@@ -295,7 +296,7 @@ class RSSM(nn.Module):
             return {"logit": logit}
         else:
             if name == "ims":
-                x = self._ims_stat_layer(x)
+                x = self._imgs_stat_layer(x)
             elif name == "obs":
                 x = self._obs_stat_layer(x)
             else:
@@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
                 act,
                 norm,
                 symlog_inputs=symlog_inputs,
+                name="Encoder",
             )
             self.outdim += mlp_units
 
@@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
         cnn_sigmoid,
         image_dist,
         vector_dist,
+        outscale,
     ):
         super(MultiDecoder, self).__init__()
         excluded = ("is_first", "is_last", "is_terminal")
@@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
                 norm,
                 kernel_size,
                 minres,
+                outscale=outscale,
                 cnn_sigmoid=cnn_sigmoid,
             )
         if self.mlp_shapes:
@@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
                 act,
                 norm,
                 vector_dist,
+                outscale=outscale,
+                name="Decoder",
             )
         self._image_dist = image_dist
 
@@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
         input_shape,
         depth=32,
         act="SiLU",
-        norm="LayerNorm",
+        norm=True,
         kernel_size=4,
         minres=4,
     ):
         super(ConvEncoder, self).__init__()
         act = getattr(torch.nn, act)
-        norm = getattr(torch.nn, norm)
         h, w, input_ch = input_shape
+        stages = int(np.log2(h) - np.log2(minres))
+        in_dim = input_ch
+        out_dim = depth
         layers = []
-        for i in range(int(np.log2(h) - np.log2(minres))):
-            if i == 0:
-                in_dim = input_ch
-            else:
-                in_dim = 2 ** (i - 1) * depth
-            out_dim = 2**i * depth
+        for i in range(stages):
             layers.append(
                 Conv2dSame(
                     in_channels=in_dim,
@@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
                     bias=False,
                 )
             )
-            layers.append(ChLayerNorm(out_dim))
+            if norm:
+                layers.append(ChLayerNorm(out_dim))
             layers.append(act())
+            in_dim = out_dim
+            out_dim *= 2
             h, w = h // 2, w // 2
 
-        self.outdim = out_dim * h * w
+        self.outdim = out_dim // 2 * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)
 
     def forward(self, obs):
+        obs -= 0.5
         # (batch, time, h, w, ch) -> (batch * time, h, w, ch)
         x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
         # (batch * time, h, w, ch) -> (batch * time, ch, h, w)
@@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
         shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
-        norm=nn.LayerNorm,
+        norm=True,
         kernel_size=4,
         minres=4,
         outscale=1.0,
@@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
     ):
         super(ConvDecoder, self).__init__()
         act = getattr(torch.nn, act)
-        norm = getattr(torch.nn, norm)
         self._shape = shape
         self._cnn_sigmoid = cnn_sigmoid
         layer_num = int(np.log2(shape[1]) - np.log2(minres))
         self._minres = minres
-        self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
+        out_ch = minres**2 * depth * 2 ** (layer_num - 1)
+        self._embed_size = out_ch
 
-        self._linear_layer = nn.Linear(feat_size, self._embed_size)
-        self._linear_layer.apply(tools.weight_init)
-        in_dim = self._embed_size // (minres**2)
+        self._linear_layer = nn.Linear(feat_size, out_ch)
+        self._linear_layer.apply(tools.uniform_weight_init(outscale))
+        in_dim = out_ch // (minres**2)
+        out_dim = in_dim // 2
 
         layers = []
         h, w = minres, minres
         for i in range(layer_num):
-            out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
             bias = False
-            initializer = tools.weight_init
             if i == layer_num - 1:
                 out_dim = self._shape[0]
                 act = False
                 bias = True
                 norm = False
-                initializer = tools.uniform_weight_init(outscale)
 
             if i != 0:
                 in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
@@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
                 layers.append(ChLayerNorm(out_dim))
             if act:
                 layers.append(act())
-            [m.apply(initializer) for m in layers[-3:]]
+            in_dim = out_dim
+            out_dim //= 2
             h, w = h * 2, w * 2
-
+        [m.apply(tools.weight_init) for m in layers[:-1]]
+        layers[-1].apply(tools.uniform_weight_init(outscale))
         self.layers = nn.Sequential(*layers)
 
     def calc_same_pad(self, k, s, d):
@@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
         # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
-        # (batch, time, -1) -> (batch * time, ch, h, w) necessary???
+        # (batch, time, -1) -> (batch, time, ch, h, w)
        mean = x.reshape(features.shape[:-1] + self._shape)
-        # (batch * time, ch, h, w) -> (batch * time, h, w, ch)
+        # (batch, time, ch, h, w) -> (batch, time, h, w, ch)
         mean = mean.permute(0, 1, 3, 4, 2)
         if self._cnn_sigmoid:
-            mean = F.sigmoid(mean) - 0.5
+            mean = F.sigmoid(mean)
+        else:
+            mean += 0.5
         return mean
diff --git a/tools.py b/tools.py
index b12c52e..1aff067 100644
--- a/tools.py
+++ b/tools.py
@@ -920,7 +920,9 @@ def weight_init(m):
         denoms = (in_num + out_num) / 2.0
         scale = 1.0 / denoms
         std = np.sqrt(scale) / 0.87962566103423978
-        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+        nn.init.trunc_normal_(
+            m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
+        )
         if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)
     elif isinstance(m, nn.LayerNorm):
@@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
             nn.init.uniform_(m.weight.data, a=-limit, b=limit)
             if hasattr(m.bias, "data"):
                 m.bias.data.fill_(0.0)
+        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+            space = m.kernel_size[0] * m.kernel_size[1]
+            in_num = space * m.in_channels
+            out_num = space * m.out_channels
+            denoms = (in_num + out_num) / 2.0
+            scale = given_scale / denoms
+            limit = np.sqrt(3 * scale)
+            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
+            if hasattr(m.bias, "data"):
+                m.bias.data.fill_(0.0)
         elif isinstance(m, nn.LayerNorm):
             m.weight.data.fill_(1.0)
             if hasattr(m.bias, "data"):
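
Notes on the change (editorial sketches in plain Python; illustrative, not part of the diff):

1. Pixel range. preprocess() now keeps images in [0, 1] instead of shifting them to [-0.5, 0.5]. ConvEncoder centers its own input (obs -= 0.5) and ConvDecoder shifts its output back into [0, 1] (mean += 0.5 on the non-sigmoid path, or a plain sigmoid otherwise), which is why video_pred() drops its +0.5 corrections. A minimal round trip with stand-in tensors, not repo code:

    import torch

    obs_uint8 = torch.randint(0, 256, (1, 64, 64, 3), dtype=torch.uint8)
    image = obs_uint8.float() / 255.0   # preprocess(): now [0, 1], no -0.5 shift
    enc_in = image - 0.5                # ConvEncoder.forward(): obs -= 0.5
    dec_raw = torch.zeros_like(image)   # stand-in for the deconv stack's output
    dec_out = dec_raw + 0.5             # ConvDecoder: mean += 0.5 (non-sigmoid path)
    assert dec_out.min() >= 0.0 and dec_out.max() <= 1.0  # same range as `image`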
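
2. Channel bookkeeping. The rewritten ConvEncoder loop tracks in_dim/out_dim incrementally, and out_dim is doubled once more than it is consumed, so the flattened feature size is out_dim // 2 * h * w. A hypothetical trace for a 64x64 RGB input with depth=32, minres=4:

    import numpy as np

    depth, minres = 32, 4
    h = w = 64                      # input resolution
    in_dim, out_dim = 3, depth      # RGB in, `depth` channels after stage 0
    stages = int(np.log2(h) - np.log2(minres))  # 4 stages: 64 -> 32 -> 16 -> 8 -> 4
    for _ in range(stages):
        print(f"{in_dim:3d} -> {out_dim}")  # 3->32, 32->64, 64->128, 128->256
        in_dim, out_dim = out_dim, out_dim * 2
        h, w = h // 2, w // 2
    # out_dim was doubled one time too many, hence the `// 2`:
    print(out_dim // 2 * h * w)     # self.outdim = 256 * 4 * 4 = 4096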
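
3. Initialization. The substantive fix is in tools.weight_init: nn.init.trunc_normal_ was called with the absolute bounds a=-2.0, b=2.0, but std here is on the order of 0.05, so truncation sat tens of standard deviations out and effectively never fired, while std was still divided by 0.87962566... (the variance correction that assumes truncation at +/-2 sigma), leaving weights roughly 14% wider than intended. The corrected call truncates at a=-2.0 * std, b=2.0 * std. The new Conv2d/ConvTranspose2d branch in tools.uniform_weight_init supplies a matching uniform variant so conv output layers can be scaled by outscale. A self-contained sketch of both schemes; trunc_normal_conv_ and scaled_uniform_conv_ are illustrative names rather than repository identifiers, and the bias guard is simplified to m.bias is not None:

    import numpy as np
    from torch import nn

    def trunc_normal_conv_(m: nn.Module) -> None:
        # Xavier-style truncated normal for conv layers, truncated at +/-2 std
        # (the corrected behavior of tools.weight_init).
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            space = m.kernel_size[0] * m.kernel_size[1]
            denoms = (space * m.in_channels + space * m.out_channels) / 2.0
            std = np.sqrt(1.0 / denoms) / 0.87962566103423978
            nn.init.trunc_normal_(
                m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
            )
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def scaled_uniform_conv_(m: nn.Module, given_scale: float = 1.0) -> None:
        # Mirrors the Conv2d/ConvTranspose2d branch added to
        # tools.uniform_weight_init: Var[U(-limit, limit)] = limit**2 / 3
        # = given_scale / denoms, i.e. a gain-scaled Xavier uniform.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            space = m.kernel_size[0] * m.kernel_size[1]
            denoms = (space * m.in_channels + space * m.out_channels) / 2.0
            limit = np.sqrt(3.0 * given_scale / denoms)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    # With given_scale=1.0 both schemes target the same weight variance, which
    # is why the decoder's output layer can swap one for the other and change
    # only the distribution's shape, not its scale.
    conv = nn.ConvTranspose2d(48, 3, kernel_size=4, stride=2)
    scaled_uniform_conv_(conv, given_scale=1.0)
    print(float(conv.weight.std()))  # ~= sqrt(1 / 408) ~= 0.05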