expanded the supported image sizes

This commit is contained in:
NM512 2023-05-21 22:00:59 +09:00
parent 02c3d45fcf
commit 0faa10ff46
2 changed files with 53 additions and 43 deletions

View File

@ -52,9 +52,9 @@ defaults:
act: 'SiLU'
norm: 'LayerNorm'
encoder:
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
decoder:
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,}
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
value_head: 'symlog_disc'
reward_head: 'symlog_disc'
dyn_scale: '0.5'

View File

@ -337,7 +337,8 @@ class MultiEncoder(nn.Module):
act,
norm,
cnn_depth,
cnn_kernels,
kernel_size,
minres,
mlp_layers,
mlp_units,
symlog_inputs,
@ -359,7 +360,10 @@ class MultiEncoder(nn.Module):
self.outdim = 0
if self.cnn_shapes:
input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
input_shape = tuple(self.cnn_shapes.values())[0][:2] + (input_ch,)
self._cnn = ConvEncoder(
input_shape, cnn_depth, act, norm, kernel_size, minres
)
self.outdim += self._cnn.outdim
if self.mlp_shapes:
input_size = sum([sum(v) for v in self.mlp_shapes.values()])
@ -396,7 +400,8 @@ class MultiDecoder(nn.Module):
act,
norm,
cnn_depth,
cnn_kernels,
kernel_size,
minres,
mlp_layers,
mlp_units,
cnn_sigmoid,
@ -426,7 +431,8 @@ class MultiDecoder(nn.Module):
cnn_depth,
act,
norm,
cnn_kernels,
kernel_size,
minres,
cnn_sigmoid=cnn_sigmoid,
)
if self.mlp_shapes:
@ -470,35 +476,39 @@ class MultiDecoder(nn.Module):
class ConvEncoder(nn.Module):
def __init__(
self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
self,
input_shape,
depth=32,
act="SiLU",
norm="LayerNorm",
kernel_size=4,
minres=4,
):
super(ConvEncoder, self).__init__()
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._depth = depth
self._kernels = kernels
h, w = 64, 64
h, w, input_ch = input_shape
layers = []
for i, kernel in enumerate(self._kernels):
for i in range(int(np.log2(h) - np.log2(minres))):
if i == 0:
inp_dim = input_ch
in_dim = input_ch
else:
inp_dim = 2 ** (i - 1) * self._depth
depth = 2**i * self._depth
in_dim = 2 ** (i - 1) * depth
out_dim = 2**i * depth
layers.append(
Conv2dSame(
in_channels=inp_dim,
out_channels=depth,
kernel_size=(kernel, kernel),
stride=(2, 2),
in_channels=in_dim,
out_channels=out_dim,
kernel_size=kernel_size,
stride=2,
bias=False,
)
)
layers.append(ChLayerNorm(depth))
layers.append(ChLayerNorm(out_dim))
layers.append(act())
h, w = h // 2, w // 2
self.outdim = depth * h * w
self.outdim = out_dim * h * w
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
@ -517,52 +527,50 @@ class ConvEncoder(nn.Module):
class ConvDecoder(nn.Module):
def __init__(
self,
inp_depth,
feat_size,
shape=(3, 64, 64),
depth=32,
act=nn.ELU,
norm=nn.LayerNorm,
kernels=(3, 3, 3, 3),
kernel_size=4,
minres=4,
outscale=1.0,
cnn_sigmoid=False,
):
super(ConvDecoder, self).__init__()
self._inp_depth = inp_depth
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._depth = depth
self._shape = shape
self._kernels = kernels
self._cnn_sigmoid = cnn_sigmoid
self._embed_size = (
(64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
)
layer_num = int(np.log2(shape[1]) - np.log2(minres))
self._minres = minres
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
self._linear_layer = nn.Linear(inp_depth, self._embed_size)
inp_dim = self._embed_size // 16
self._linear_layer = nn.Linear(feat_size, self._embed_size)
in_dim = self._embed_size // (minres**2)
layers = []
h, w = 4, 4
for i, kernel in enumerate(self._kernels):
depth = self._embed_size // 16 // (2 ** (i + 1))
h, w = minres, minres
for i in range(layer_num):
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
bias = False
initializer = tools.weight_init
if i == len(self._kernels) - 1:
depth = self._shape[0]
if i == layer_num - 1:
out_dim = self._shape[0]
act = False
bias = True
norm = False
initializer = tools.uniform_weight_init(outscale)
if i != 0:
inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
pad_h, outpad_h = self.calc_same_pad(k=kernel_size, s=2, d=1)
pad_w, outpad_w = self.calc_same_pad(k=kernel_size, s=2, d=1)
layers.append(
nn.ConvTranspose2d(
inp_dim,
depth,
kernel,
in_dim,
out_dim,
kernel_size,
2,
padding=(pad_h, pad_w),
output_padding=(outpad_h, outpad_w),
@ -570,7 +578,7 @@ class ConvDecoder(nn.Module):
)
)
if norm:
layers.append(ChLayerNorm(depth))
layers.append(ChLayerNorm(out_dim))
if act:
layers.append(act())
[m.apply(initializer) for m in layers[-3:]]
@ -587,7 +595,9 @@ class ConvDecoder(nn.Module):
def forward(self, features, dtype=None):
x = self._linear_layer(features)
# (batch, time, -1) -> (batch * time, h, w, ch)
x = x.reshape([-1, 4, 4, self._embed_size // 16])
x = x.reshape(
[-1, self._minres, self._minres, self._embed_size // self._minres**2]
)
# (batch, time, -1) -> (batch * time, ch, h, w)
x = x.permute(0, 3, 1, 2)
x = self.layers(x)