From 0faa10ff46c01201317e16ab1cafa9de87cfe76f Mon Sep 17 00:00:00 2001
From: NM512
Date: Sun, 21 May 2023 22:00:59 +0900
Subject: [PATCH] expanded the supported image sizes

---
 configs.yaml |  4 +--
 networks.py  | 92 +++++++++++++++++++++++++++++-----------------------
 2 files changed, 53 insertions(+), 43 deletions(-)

diff --git a/configs.yaml b/configs.yaml
index 21718c0..332a977 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -52,9 +52,9 @@ defaults:
   act: 'SiLU'
   norm: 'LayerNorm'
   encoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
   value_head: 'symlog_disc'
   reward_head: 'symlog_disc'
   dyn_scale: '0.5'
diff --git a/networks.py b/networks.py
index d122ec2..9c58faf 100644
--- a/networks.py
+++ b/networks.py
@@ -337,7 +337,8 @@ class MultiEncoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         symlog_inputs,
@@ -359,7 +360,10 @@ class MultiEncoder(nn.Module):
         self.outdim = 0
         if self.cnn_shapes:
             input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
-            self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
+            input_shape = tuple(self.cnn_shapes.values())[0][:2] + (input_ch,)
+            self._cnn = ConvEncoder(
+                input_shape, cnn_depth, act, norm, kernel_size, minres
+            )
             self.outdim += self._cnn.outdim
         if self.mlp_shapes:
             input_size = sum([sum(v) for v in self.mlp_shapes.values()])
@@ -396,7 +400,8 @@ class MultiDecoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         cnn_sigmoid,
@@ -426,7 +431,8 @@ class MultiDecoder(nn.Module):
                 cnn_depth,
                 act,
                 norm,
-                cnn_kernels,
+                kernel_size,
+                minres,
                 cnn_sigmoid=cnn_sigmoid,
             )
         if self.mlp_shapes:
@@ -470,35 +476,39 @@ class MultiDecoder(nn.Module):
 
 class ConvEncoder(nn.Module):
     def __init__(
-        self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
+        self,
+        input_shape,
+        depth=32,
+        act="SiLU",
+        norm="LayerNorm",
+        kernel_size=4,
+        minres=4,
     ):
         super(ConvEncoder, self).__init__()
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
-        self._kernels = kernels
-        h, w = 64, 64
+        h, w, input_ch = input_shape
 
         layers = []
-        for i, kernel in enumerate(self._kernels):
+        for i in range(int(np.log2(h) - np.log2(minres))):
             if i == 0:
-                inp_dim = input_ch
+                in_dim = input_ch
             else:
-                inp_dim = 2 ** (i - 1) * self._depth
-            depth = 2**i * self._depth
+                in_dim = 2 ** (i - 1) * depth
+            out_dim = 2**i * depth
             layers.append(
                 Conv2dSame(
-                    in_channels=inp_dim,
-                    out_channels=depth,
-                    kernel_size=(kernel, kernel),
-                    stride=(2, 2),
+                    in_channels=in_dim,
+                    out_channels=out_dim,
+                    kernel_size=kernel_size,
+                    stride=2,
                     bias=False,
                 )
             )
-            layers.append(ChLayerNorm(depth))
+            layers.append(ChLayerNorm(out_dim))
             layers.append(act())
             h, w = h // 2, w // 2
-        self.outdim = depth * h * w
+        self.outdim = out_dim * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)
 
@@ -517,52 +527,50 @@ class ConvEncoder(nn.Module):
 
 class ConvDecoder(nn.Module):
     def __init__(
         self,
-        inp_depth,
+        feat_size,
         shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
         norm=nn.LayerNorm,
-        kernels=(3, 3, 3, 3),
+        kernel_size=4,
+        minres=4,
         outscale=1.0,
         cnn_sigmoid=False,
     ):
         super(ConvDecoder, self).__init__()
-        self._inp_depth = inp_depth
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
         self._shape = shape
-        self._kernels = kernels
         self._cnn_sigmoid = cnn_sigmoid
-        self._embed_size = (
-            (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
-        )
+        layer_num = int(np.log2(shape[1]) - np.log2(minres))
+        self._minres = minres
+        self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
 
-        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
-        inp_dim = self._embed_size // 16
+        self._linear_layer = nn.Linear(feat_size, self._embed_size)
+        in_dim = self._embed_size // (minres**2)
 
         layers = []
-        h, w = 4, 4
-        for i, kernel in enumerate(self._kernels):
-            depth = self._embed_size // 16 // (2 ** (i + 1))
+        h, w = minres, minres
+        for i in range(layer_num):
+            out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
             bias = False
             initializer = tools.weight_init
-            if i == len(self._kernels) - 1:
-                depth = self._shape[0]
+            if i == layer_num - 1:
+                out_dim = self._shape[0]
                 act = False
                 bias = True
                 norm = False
                 initializer = tools.uniform_weight_init(outscale)
             if i != 0:
-                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
-            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
-            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
+                in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
+            pad_h, outpad_h = self.calc_same_pad(k=kernel_size, s=2, d=1)
+            pad_w, outpad_w = self.calc_same_pad(k=kernel_size, s=2, d=1)
             layers.append(
                 nn.ConvTranspose2d(
-                    inp_dim,
-                    depth,
-                    kernel,
+                    in_dim,
+                    out_dim,
+                    kernel_size,
                     2,
                     padding=(pad_h, pad_w),
                     output_padding=(outpad_h, outpad_w),
@@ -570,7 +578,7 @@ class ConvDecoder(nn.Module):
                 )
             )
             if norm:
-                layers.append(ChLayerNorm(depth))
+                layers.append(ChLayerNorm(out_dim))
             if act:
                 layers.append(act())
             [m.apply(initializer) for m in layers[-3:]]
@@ -587,7 +595,9 @@ class ConvDecoder(nn.Module):
     def forward(self, features, dtype=None):
         x = self._linear_layer(features)
         # (batch, time, -1) -> (batch * time, h, w, ch)
-        x = x.reshape([-1, 4, 4, self._embed_size // 16])
+        x = x.reshape(
+            [-1, self._minres, self._minres, self._embed_size // self._minres**2]
+        )
         # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
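Note (reviewer sketch, not part of the patch): after this change the number of stride-2 stages in ConvEncoder/ConvDecoder is no longer fixed by cnn_kernels but is derived from the image resolution and the new minres option as int(np.log2(h) - np.log2(minres)), so the image height/width should be a power-of-two multiple of minres (64x64 with minres=4 keeps the previous 4 stages; 128x128 yields 5). A minimal standalone sketch of that relationship, using a hypothetical helper name:

    import numpy as np

    def num_conv_stages(resolution, minres=4):
        # Hypothetical helper (not in the patch) mirroring the stage-count
        # formula ConvEncoder/ConvDecoder use after this change.
        stages = int(np.log2(resolution) - np.log2(minres))
        # Stride-2 downsampling only halves cleanly down to minres when
        # resolution == minres * 2**k.
        assert minres * 2**stages == resolution, "resolution must be minres * 2**k"
        return stages

    for res in (32, 64, 128, 256):
        print(res, num_conv_stages(res))  # 3, 4, 5, 6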