modified weight initialization

This commit is contained in:
NM512 2024-01-05 10:46:54 +09:00
parent 4fe9b29ebe
commit a9e85e8b7c
3 changed files with 61 additions and 40 deletions

View File

@ -179,7 +179,7 @@ class WorldModel(nn.Module):
# this function is called during both rollout and training
def preprocess(self, obs):
obs = obs.copy()
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
obs["image"] = torch.Tensor(obs["image"]) / 255.0
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
@ -209,8 +209,8 @@ class WorldModel(nn.Module):
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed image is given until 5 steps
model = torch.cat([recon[:, :5], openl], 1)
truth = data["image"][:6] + 0.5
model = model + 0.5
truth = data["image"][:6]
model = model
error = (model - truth + 1.0) / 2.0
return torch.cat([truth, model, error], 2)

View File

@ -68,9 +68,8 @@ class RSSM(nn.Module):
inp_layers.append(act())
if i == 0:
inp_dim = self._hidden
self._inp_layers = nn.Sequential(*inp_layers)
self._inp_layers.apply(tools.weight_init)
self._img_in_layers = nn.Sequential(*inp_layers)
self._img_in_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
@ -106,15 +105,17 @@ class RSSM(nn.Module):
self._obs_out_layers.apply(tools.weight_init)
if self._discrete:
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._ims_stat_layer.apply(tools.weight_init)
self._imgs_stat_layer = nn.Linear(
self._hidden, self._stoch * self._discrete
)
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._obs_stat_layer.apply(tools.weight_init)
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
else:
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._ims_stat_layer.apply(tools.weight_init)
self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
if self._initial == "learned":
self.W = torch.nn.Parameter(
@ -260,7 +261,7 @@ class RSSM(nn.Module):
else:
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
x = self._inp_layers(x)
x = self._img_in_layers(x)
for _ in range(self._rec_depth): # rec depth is not correctly implemented
deter = prev_state["deter"]
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
@ -286,7 +287,7 @@ class RSSM(nn.Module):
def _suff_stats_layer(self, name, x):
if self._discrete:
if name == "ims":
x = self._ims_stat_layer(x)
x = self._imgs_stat_layer(x)
elif name == "obs":
x = self._obs_stat_layer(x)
else:
@ -295,7 +296,7 @@ class RSSM(nn.Module):
return {"logit": logit}
else:
if name == "ims":
x = self._ims_stat_layer(x)
x = self._imgs_stat_layer(x)
elif name == "obs":
x = self._obs_stat_layer(x)
else:
@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
act,
norm,
symlog_inputs=symlog_inputs,
name="Encoder",
)
self.outdim += mlp_units
@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
cnn_sigmoid,
image_dist,
vector_dist,
outscale,
):
super(MultiDecoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal")
@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
norm,
kernel_size,
minres,
outscale=outscale,
cnn_sigmoid=cnn_sigmoid,
)
if self.mlp_shapes:
@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
act,
norm,
vector_dist,
outscale=outscale,
name="Decoder",
)
self._image_dist = image_dist
@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
input_shape,
depth=32,
act="SiLU",
norm="LayerNorm",
norm=True,
kernel_size=4,
minres=4,
):
super(ConvEncoder, self).__init__()
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
h, w, input_ch = input_shape
stages = int(np.log2(h) - np.log2(minres))
in_dim = input_ch
out_dim = depth
layers = []
for i in range(int(np.log2(h) - np.log2(minres))):
if i == 0:
in_dim = input_ch
else:
in_dim = 2 ** (i - 1) * depth
out_dim = 2**i * depth
for i in range(stages):
layers.append(
Conv2dSame(
in_channels=in_dim,
@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
bias=False,
)
)
layers.append(ChLayerNorm(out_dim))
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(act())
in_dim = out_dim
out_dim *= 2
h, w = h // 2, w // 2
self.outdim = out_dim * h * w
self.outdim = out_dim // 2 * h * w
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
def forward(self, obs):
obs -= 0.5
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
shape=(3, 64, 64),
depth=32,
act=nn.ELU,
norm=nn.LayerNorm,
norm=True,
kernel_size=4,
minres=4,
outscale=1.0,
@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
):
super(ConvDecoder, self).__init__()
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._shape = shape
self._cnn_sigmoid = cnn_sigmoid
layer_num = int(np.log2(shape[1]) - np.log2(minres))
self._minres = minres
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
out_ch = minres**2 * depth * 2 ** (layer_num - 1)
self._embed_size = out_ch
self._linear_layer = nn.Linear(feat_size, self._embed_size)
self._linear_layer.apply(tools.weight_init)
in_dim = self._embed_size // (minres**2)
self._linear_layer = nn.Linear(feat_size, out_ch)
self._linear_layer.apply(tools.uniform_weight_init(outscale))
in_dim = out_ch // (minres**2)
out_dim = in_dim // 2
layers = []
h, w = minres, minres
for i in range(layer_num):
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
bias = False
initializer = tools.weight_init
if i == layer_num - 1:
out_dim = self._shape[0]
act = False
bias = True
norm = False
initializer = tools.uniform_weight_init(outscale)
if i != 0:
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
layers.append(ChLayerNorm(out_dim))
if act:
layers.append(act())
[m.apply(initializer) for m in layers[-3:]]
in_dim = out_dim
out_dim //= 2
h, w = h * 2, w * 2
[m.apply(tools.weight_init) for m in layers[:-1]]
layers[-1].apply(tools.uniform_weight_init(outscale))
self.layers = nn.Sequential(*layers)
def calc_same_pad(self, k, s, d):
@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
# (batch, time, -1) -> (batch * time, ch, h, w)
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
# (batch, time, -1) -> (batch, time, ch, h, w)
mean = x.reshape(features.shape[:-1] + self._shape)
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
mean = mean.permute(0, 1, 3, 4, 2)
if self._cnn_sigmoid:
mean = F.sigmoid(mean) - 0.5
mean = F.sigmoid(mean)
else:
mean += 0.5
return mean

View File

@ -920,7 +920,9 @@ def weight_init(m):
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
nn.init.trunc_normal_(
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
space = m.kernel_size[0] * m.kernel_size[1]
in_num = space * m.in_channels
out_num = space * m.out_channels
denoms = (in_num + out_num) / 2.0
scale = given_scale / denoms
limit = np.sqrt(3 * scale)
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, "data"):