modified weight initialization
This commit is contained in:
parent
4fe9b29ebe
commit
a9e85e8b7c
@ -179,7 +179,7 @@ class WorldModel(nn.Module):
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
@ -209,8 +209,8 @@ class WorldModel(nn.Module):
|
||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||
# observed image is given until 5 steps
|
||||
model = torch.cat([recon[:, :5], openl], 1)
|
||||
truth = data["image"][:6] + 0.5
|
||||
model = model + 0.5
|
||||
truth = data["image"][:6]
|
||||
model = model
|
||||
error = (model - truth + 1.0) / 2.0
|
||||
|
||||
return torch.cat([truth, model, error], 2)
|
||||
|
81
networks.py
81
networks.py
@ -68,9 +68,8 @@ class RSSM(nn.Module):
|
||||
inp_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._inp_layers = nn.Sequential(*inp_layers)
|
||||
self._inp_layers.apply(tools.weight_init)
|
||||
|
||||
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||
self._img_in_layers.apply(tools.weight_init)
|
||||
if cell == "gru":
|
||||
self._cell = GRUCell(self._hidden, self._deter)
|
||||
self._cell.apply(tools.weight_init)
|
||||
@ -106,15 +105,17 @@ class RSSM(nn.Module):
|
||||
self._obs_out_layers.apply(tools.weight_init)
|
||||
|
||||
if self._discrete:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._imgs_stat_layer = nn.Linear(
|
||||
self._hidden, self._stoch * self._discrete
|
||||
)
|
||||
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
else:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
|
||||
if self._initial == "learned":
|
||||
self.W = torch.nn.Parameter(
|
||||
@ -260,7 +261,7 @@ class RSSM(nn.Module):
|
||||
else:
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||
x = self._inp_layers(x)
|
||||
x = self._img_in_layers(x)
|
||||
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
||||
deter = prev_state["deter"]
|
||||
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
|
||||
@ -286,7 +287,7 @@ class RSSM(nn.Module):
|
||||
def _suff_stats_layer(self, name, x):
|
||||
if self._discrete:
|
||||
if name == "ims":
|
||||
x = self._ims_stat_layer(x)
|
||||
x = self._imgs_stat_layer(x)
|
||||
elif name == "obs":
|
||||
x = self._obs_stat_layer(x)
|
||||
else:
|
||||
@ -295,7 +296,7 @@ class RSSM(nn.Module):
|
||||
return {"logit": logit}
|
||||
else:
|
||||
if name == "ims":
|
||||
x = self._ims_stat_layer(x)
|
||||
x = self._imgs_stat_layer(x)
|
||||
elif name == "obs":
|
||||
x = self._obs_stat_layer(x)
|
||||
else:
|
||||
@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
|
||||
act,
|
||||
norm,
|
||||
symlog_inputs=symlog_inputs,
|
||||
name="Encoder",
|
||||
)
|
||||
self.outdim += mlp_units
|
||||
|
||||
@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
|
||||
cnn_sigmoid,
|
||||
image_dist,
|
||||
vector_dist,
|
||||
outscale,
|
||||
):
|
||||
super(MultiDecoder, self).__init__()
|
||||
excluded = ("is_first", "is_last", "is_terminal")
|
||||
@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
|
||||
norm,
|
||||
kernel_size,
|
||||
minres,
|
||||
outscale=outscale,
|
||||
cnn_sigmoid=cnn_sigmoid,
|
||||
)
|
||||
if self.mlp_shapes:
|
||||
@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
|
||||
act,
|
||||
norm,
|
||||
vector_dist,
|
||||
outscale=outscale,
|
||||
name="Decoder",
|
||||
)
|
||||
self._image_dist = image_dist
|
||||
|
||||
@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
|
||||
input_shape,
|
||||
depth=32,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
kernel_size=4,
|
||||
minres=4,
|
||||
):
|
||||
super(ConvEncoder, self).__init__()
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
h, w, input_ch = input_shape
|
||||
stages = int(np.log2(h) - np.log2(minres))
|
||||
in_dim = input_ch
|
||||
out_dim = depth
|
||||
layers = []
|
||||
for i in range(int(np.log2(h) - np.log2(minres))):
|
||||
if i == 0:
|
||||
in_dim = input_ch
|
||||
else:
|
||||
in_dim = 2 ** (i - 1) * depth
|
||||
out_dim = 2**i * depth
|
||||
for i in range(stages):
|
||||
layers.append(
|
||||
Conv2dSame(
|
||||
in_channels=in_dim,
|
||||
@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
out_dim *= 2
|
||||
h, w = h // 2, w // 2
|
||||
|
||||
self.outdim = out_dim * h * w
|
||||
self.outdim = out_dim // 2 * h * w
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
def forward(self, obs):
|
||||
obs -= 0.5
|
||||
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
|
||||
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
|
||||
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
|
||||
@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
|
||||
shape=(3, 64, 64),
|
||||
depth=32,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
norm=True,
|
||||
kernel_size=4,
|
||||
minres=4,
|
||||
outscale=1.0,
|
||||
@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
|
||||
):
|
||||
super(ConvDecoder, self).__init__()
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._shape = shape
|
||||
self._cnn_sigmoid = cnn_sigmoid
|
||||
layer_num = int(np.log2(shape[1]) - np.log2(minres))
|
||||
self._minres = minres
|
||||
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
|
||||
out_ch = minres**2 * depth * 2 ** (layer_num - 1)
|
||||
self._embed_size = out_ch
|
||||
|
||||
self._linear_layer = nn.Linear(feat_size, self._embed_size)
|
||||
self._linear_layer.apply(tools.weight_init)
|
||||
in_dim = self._embed_size // (minres**2)
|
||||
self._linear_layer = nn.Linear(feat_size, out_ch)
|
||||
self._linear_layer.apply(tools.uniform_weight_init(outscale))
|
||||
in_dim = out_ch // (minres**2)
|
||||
out_dim = in_dim // 2
|
||||
|
||||
layers = []
|
||||
h, w = minres, minres
|
||||
for i in range(layer_num):
|
||||
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
|
||||
bias = False
|
||||
initializer = tools.weight_init
|
||||
if i == layer_num - 1:
|
||||
out_dim = self._shape[0]
|
||||
act = False
|
||||
bias = True
|
||||
norm = False
|
||||
initializer = tools.uniform_weight_init(outscale)
|
||||
|
||||
if i != 0:
|
||||
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
|
||||
@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
if act:
|
||||
layers.append(act())
|
||||
[m.apply(initializer) for m in layers[-3:]]
|
||||
in_dim = out_dim
|
||||
out_dim //= 2
|
||||
h, w = h * 2, w * 2
|
||||
|
||||
[m.apply(tools.weight_init) for m in layers[:-1]]
|
||||
layers[-1].apply(tools.uniform_weight_init(outscale))
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def calc_same_pad(self, k, s, d):
|
||||
@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
|
||||
# (batch, time, -1) -> (batch, time, ch, h, w)
|
||||
mean = x.reshape(features.shape[:-1] + self._shape)
|
||||
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
|
||||
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
|
||||
mean = mean.permute(0, 1, 3, 4, 2)
|
||||
if self._cnn_sigmoid:
|
||||
mean = F.sigmoid(mean) - 0.5
|
||||
mean = F.sigmoid(mean)
|
||||
else:
|
||||
mean += 0.5
|
||||
return mean
|
||||
|
||||
|
||||
|
14
tools.py
14
tools.py
@ -920,7 +920,9 @@ def weight_init(m):
|
||||
denoms = (in_num + out_num) / 2.0
|
||||
scale = 1.0 / denoms
|
||||
std = np.sqrt(scale) / 0.87962566103423978
|
||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
||||
nn.init.trunc_normal_(
|
||||
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
|
||||
)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
|
||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||
space = m.kernel_size[0] * m.kernel_size[1]
|
||||
in_num = space * m.in_channels
|
||||
out_num = space * m.out_channels
|
||||
denoms = (in_num + out_num) / 2.0
|
||||
scale = given_scale / denoms
|
||||
limit = np.sqrt(3 * scale)
|
||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
m.weight.data.fill_(1.0)
|
||||
if hasattr(m.bias, "data"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user