modified weight initialization
This commit is contained in:
parent
4fe9b29ebe
commit
a9e85e8b7c
@ -179,7 +179,7 @@ class WorldModel(nn.Module):
|
|||||||
# this function is called during both rollout and training
|
# this function is called during both rollout and training
|
||||||
def preprocess(self, obs):
|
def preprocess(self, obs):
|
||||||
obs = obs.copy()
|
obs = obs.copy()
|
||||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
|
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||||
if "discount" in obs:
|
if "discount" in obs:
|
||||||
obs["discount"] *= self._config.discount
|
obs["discount"] *= self._config.discount
|
||||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||||
@ -209,8 +209,8 @@ class WorldModel(nn.Module):
|
|||||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||||
# observed image is given until 5 steps
|
# observed image is given until 5 steps
|
||||||
model = torch.cat([recon[:, :5], openl], 1)
|
model = torch.cat([recon[:, :5], openl], 1)
|
||||||
truth = data["image"][:6] + 0.5
|
truth = data["image"][:6]
|
||||||
model = model + 0.5
|
model = model
|
||||||
error = (model - truth + 1.0) / 2.0
|
error = (model - truth + 1.0) / 2.0
|
||||||
|
|
||||||
return torch.cat([truth, model, error], 2)
|
return torch.cat([truth, model, error], 2)
|
||||||
|
81
networks.py
81
networks.py
@ -68,9 +68,8 @@ class RSSM(nn.Module):
|
|||||||
inp_layers.append(act())
|
inp_layers.append(act())
|
||||||
if i == 0:
|
if i == 0:
|
||||||
inp_dim = self._hidden
|
inp_dim = self._hidden
|
||||||
self._inp_layers = nn.Sequential(*inp_layers)
|
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||||
self._inp_layers.apply(tools.weight_init)
|
self._img_in_layers.apply(tools.weight_init)
|
||||||
|
|
||||||
if cell == "gru":
|
if cell == "gru":
|
||||||
self._cell = GRUCell(self._hidden, self._deter)
|
self._cell = GRUCell(self._hidden, self._deter)
|
||||||
self._cell.apply(tools.weight_init)
|
self._cell.apply(tools.weight_init)
|
||||||
@ -106,15 +105,17 @@ class RSSM(nn.Module):
|
|||||||
self._obs_out_layers.apply(tools.weight_init)
|
self._obs_out_layers.apply(tools.weight_init)
|
||||||
|
|
||||||
if self._discrete:
|
if self._discrete:
|
||||||
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
self._imgs_stat_layer = nn.Linear(
|
||||||
self._ims_stat_layer.apply(tools.weight_init)
|
self._hidden, self._stoch * self._discrete
|
||||||
|
)
|
||||||
|
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||||
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||||
self._obs_stat_layer.apply(tools.weight_init)
|
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||||
else:
|
else:
|
||||||
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||||
self._ims_stat_layer.apply(tools.weight_init)
|
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||||
self._obs_stat_layer.apply(tools.weight_init)
|
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||||
|
|
||||||
if self._initial == "learned":
|
if self._initial == "learned":
|
||||||
self.W = torch.nn.Parameter(
|
self.W = torch.nn.Parameter(
|
||||||
@ -260,7 +261,7 @@ class RSSM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = torch.cat([prev_stoch, prev_action], -1)
|
x = torch.cat([prev_stoch, prev_action], -1)
|
||||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||||
x = self._inp_layers(x)
|
x = self._img_in_layers(x)
|
||||||
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
||||||
deter = prev_state["deter"]
|
deter = prev_state["deter"]
|
||||||
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
|
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
|
||||||
@ -286,7 +287,7 @@ class RSSM(nn.Module):
|
|||||||
def _suff_stats_layer(self, name, x):
|
def _suff_stats_layer(self, name, x):
|
||||||
if self._discrete:
|
if self._discrete:
|
||||||
if name == "ims":
|
if name == "ims":
|
||||||
x = self._ims_stat_layer(x)
|
x = self._imgs_stat_layer(x)
|
||||||
elif name == "obs":
|
elif name == "obs":
|
||||||
x = self._obs_stat_layer(x)
|
x = self._obs_stat_layer(x)
|
||||||
else:
|
else:
|
||||||
@ -295,7 +296,7 @@ class RSSM(nn.Module):
|
|||||||
return {"logit": logit}
|
return {"logit": logit}
|
||||||
else:
|
else:
|
||||||
if name == "ims":
|
if name == "ims":
|
||||||
x = self._ims_stat_layer(x)
|
x = self._imgs_stat_layer(x)
|
||||||
elif name == "obs":
|
elif name == "obs":
|
||||||
x = self._obs_stat_layer(x)
|
x = self._obs_stat_layer(x)
|
||||||
else:
|
else:
|
||||||
@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
|
|||||||
act,
|
act,
|
||||||
norm,
|
norm,
|
||||||
symlog_inputs=symlog_inputs,
|
symlog_inputs=symlog_inputs,
|
||||||
|
name="Encoder",
|
||||||
)
|
)
|
||||||
self.outdim += mlp_units
|
self.outdim += mlp_units
|
||||||
|
|
||||||
@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
|
|||||||
cnn_sigmoid,
|
cnn_sigmoid,
|
||||||
image_dist,
|
image_dist,
|
||||||
vector_dist,
|
vector_dist,
|
||||||
|
outscale,
|
||||||
):
|
):
|
||||||
super(MultiDecoder, self).__init__()
|
super(MultiDecoder, self).__init__()
|
||||||
excluded = ("is_first", "is_last", "is_terminal")
|
excluded = ("is_first", "is_last", "is_terminal")
|
||||||
@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
|
|||||||
norm,
|
norm,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
minres,
|
minres,
|
||||||
|
outscale=outscale,
|
||||||
cnn_sigmoid=cnn_sigmoid,
|
cnn_sigmoid=cnn_sigmoid,
|
||||||
)
|
)
|
||||||
if self.mlp_shapes:
|
if self.mlp_shapes:
|
||||||
@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
|
|||||||
act,
|
act,
|
||||||
norm,
|
norm,
|
||||||
vector_dist,
|
vector_dist,
|
||||||
|
outscale=outscale,
|
||||||
|
name="Decoder",
|
||||||
)
|
)
|
||||||
self._image_dist = image_dist
|
self._image_dist = image_dist
|
||||||
|
|
||||||
@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
|
|||||||
input_shape,
|
input_shape,
|
||||||
depth=32,
|
depth=32,
|
||||||
act="SiLU",
|
act="SiLU",
|
||||||
norm="LayerNorm",
|
norm=True,
|
||||||
kernel_size=4,
|
kernel_size=4,
|
||||||
minres=4,
|
minres=4,
|
||||||
):
|
):
|
||||||
super(ConvEncoder, self).__init__()
|
super(ConvEncoder, self).__init__()
|
||||||
act = getattr(torch.nn, act)
|
act = getattr(torch.nn, act)
|
||||||
norm = getattr(torch.nn, norm)
|
|
||||||
h, w, input_ch = input_shape
|
h, w, input_ch = input_shape
|
||||||
|
stages = int(np.log2(h) - np.log2(minres))
|
||||||
|
in_dim = input_ch
|
||||||
|
out_dim = depth
|
||||||
layers = []
|
layers = []
|
||||||
for i in range(int(np.log2(h) - np.log2(minres))):
|
for i in range(stages):
|
||||||
if i == 0:
|
|
||||||
in_dim = input_ch
|
|
||||||
else:
|
|
||||||
in_dim = 2 ** (i - 1) * depth
|
|
||||||
out_dim = 2**i * depth
|
|
||||||
layers.append(
|
layers.append(
|
||||||
Conv2dSame(
|
Conv2dSame(
|
||||||
in_channels=in_dim,
|
in_channels=in_dim,
|
||||||
@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
layers.append(ChLayerNorm(out_dim))
|
if norm:
|
||||||
|
layers.append(ChLayerNorm(out_dim))
|
||||||
layers.append(act())
|
layers.append(act())
|
||||||
|
in_dim = out_dim
|
||||||
|
out_dim *= 2
|
||||||
h, w = h // 2, w // 2
|
h, w = h // 2, w // 2
|
||||||
|
|
||||||
self.outdim = out_dim * h * w
|
self.outdim = out_dim // 2 * h * w
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
self.layers.apply(tools.weight_init)
|
self.layers.apply(tools.weight_init)
|
||||||
|
|
||||||
def forward(self, obs):
|
def forward(self, obs):
|
||||||
|
obs -= 0.5
|
||||||
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
|
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
|
||||||
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
|
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
|
||||||
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
|
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
|
||||||
@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
|
|||||||
shape=(3, 64, 64),
|
shape=(3, 64, 64),
|
||||||
depth=32,
|
depth=32,
|
||||||
act=nn.ELU,
|
act=nn.ELU,
|
||||||
norm=nn.LayerNorm,
|
norm=True,
|
||||||
kernel_size=4,
|
kernel_size=4,
|
||||||
minres=4,
|
minres=4,
|
||||||
outscale=1.0,
|
outscale=1.0,
|
||||||
@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
|
|||||||
):
|
):
|
||||||
super(ConvDecoder, self).__init__()
|
super(ConvDecoder, self).__init__()
|
||||||
act = getattr(torch.nn, act)
|
act = getattr(torch.nn, act)
|
||||||
norm = getattr(torch.nn, norm)
|
|
||||||
self._shape = shape
|
self._shape = shape
|
||||||
self._cnn_sigmoid = cnn_sigmoid
|
self._cnn_sigmoid = cnn_sigmoid
|
||||||
layer_num = int(np.log2(shape[1]) - np.log2(minres))
|
layer_num = int(np.log2(shape[1]) - np.log2(minres))
|
||||||
self._minres = minres
|
self._minres = minres
|
||||||
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
|
out_ch = minres**2 * depth * 2 ** (layer_num - 1)
|
||||||
|
self._embed_size = out_ch
|
||||||
|
|
||||||
self._linear_layer = nn.Linear(feat_size, self._embed_size)
|
self._linear_layer = nn.Linear(feat_size, out_ch)
|
||||||
self._linear_layer.apply(tools.weight_init)
|
self._linear_layer.apply(tools.uniform_weight_init(outscale))
|
||||||
in_dim = self._embed_size // (minres**2)
|
in_dim = out_ch // (minres**2)
|
||||||
|
out_dim = in_dim // 2
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
h, w = minres, minres
|
h, w = minres, minres
|
||||||
for i in range(layer_num):
|
for i in range(layer_num):
|
||||||
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
|
|
||||||
bias = False
|
bias = False
|
||||||
initializer = tools.weight_init
|
|
||||||
if i == layer_num - 1:
|
if i == layer_num - 1:
|
||||||
out_dim = self._shape[0]
|
out_dim = self._shape[0]
|
||||||
act = False
|
act = False
|
||||||
bias = True
|
bias = True
|
||||||
norm = False
|
norm = False
|
||||||
initializer = tools.uniform_weight_init(outscale)
|
|
||||||
|
|
||||||
if i != 0:
|
if i != 0:
|
||||||
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
|
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
|
||||||
@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
|
|||||||
layers.append(ChLayerNorm(out_dim))
|
layers.append(ChLayerNorm(out_dim))
|
||||||
if act:
|
if act:
|
||||||
layers.append(act())
|
layers.append(act())
|
||||||
[m.apply(initializer) for m in layers[-3:]]
|
in_dim = out_dim
|
||||||
|
out_dim //= 2
|
||||||
h, w = h * 2, w * 2
|
h, w = h * 2, w * 2
|
||||||
|
[m.apply(tools.weight_init) for m in layers[:-1]]
|
||||||
|
layers[-1].apply(tools.uniform_weight_init(outscale))
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
def calc_same_pad(self, k, s, d):
|
def calc_same_pad(self, k, s, d):
|
||||||
@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
|
|||||||
# (batch, time, -1) -> (batch * time, ch, h, w)
|
# (batch, time, -1) -> (batch * time, ch, h, w)
|
||||||
x = x.permute(0, 3, 1, 2)
|
x = x.permute(0, 3, 1, 2)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
|
# (batch, time, -1) -> (batch, time, ch, h, w)
|
||||||
mean = x.reshape(features.shape[:-1] + self._shape)
|
mean = x.reshape(features.shape[:-1] + self._shape)
|
||||||
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
|
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
|
||||||
mean = mean.permute(0, 1, 3, 4, 2)
|
mean = mean.permute(0, 1, 3, 4, 2)
|
||||||
if self._cnn_sigmoid:
|
if self._cnn_sigmoid:
|
||||||
mean = F.sigmoid(mean) - 0.5
|
mean = F.sigmoid(mean)
|
||||||
|
else:
|
||||||
|
mean += 0.5
|
||||||
return mean
|
return mean
|
||||||
|
|
||||||
|
|
||||||
|
14
tools.py
14
tools.py
@ -920,7 +920,9 @@ def weight_init(m):
|
|||||||
denoms = (in_num + out_num) / 2.0
|
denoms = (in_num + out_num) / 2.0
|
||||||
scale = 1.0 / denoms
|
scale = 1.0 / denoms
|
||||||
std = np.sqrt(scale) / 0.87962566103423978
|
std = np.sqrt(scale) / 0.87962566103423978
|
||||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
nn.init.trunc_normal_(
|
||||||
|
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
|
||||||
|
)
|
||||||
if hasattr(m.bias, "data"):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
|
|||||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||||
if hasattr(m.bias, "data"):
|
if hasattr(m.bias, "data"):
|
||||||
m.bias.data.fill_(0.0)
|
m.bias.data.fill_(0.0)
|
||||||
|
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||||
|
space = m.kernel_size[0] * m.kernel_size[1]
|
||||||
|
in_num = space * m.in_channels
|
||||||
|
out_num = space * m.out_channels
|
||||||
|
denoms = (in_num + out_num) / 2.0
|
||||||
|
scale = given_scale / denoms
|
||||||
|
limit = np.sqrt(3 * scale)
|
||||||
|
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||||
|
if hasattr(m.bias, "data"):
|
||||||
|
m.bias.data.fill_(0.0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
m.weight.data.fill_(1.0)
|
m.weight.data.fill_(1.0)
|
||||||
if hasattr(m.bias, "data"):
|
if hasattr(m.bias, "data"):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user