expanded the supported image sizes

parent 02c3d45fcf
commit 0faa10ff46
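The encoder and decoder no longer take a fixed `cnn_kernels` list (one entry per conv layer, which pinned the CNN to 64x64 inputs). They instead take a single `kernel_size` plus a target bottleneck resolution `minres`, and derive the number of strided layers from the input resolution.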
@@ -52,9 +52,9 @@ defaults:
   act: 'SiLU'
   norm: 'LayerNorm'
   encoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
   value_head: 'symlog_disc'
   reward_head: 'symlog_disc'
   dyn_scale: '0.5'
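Why this expands the supported sizes: each stride-2 layer halves the resolution, so the layer count falls out of the image size and `minres` rather than out of the length of a kernel list. A minimal sketch of the arithmetic (the helper name is illustrative, not part of the diff):

```python
import numpy as np

# Each stride-2 conv halves h and w, so reaching a minres x minres
# bottleneck takes log2(size) - log2(minres) layers. This assumes the
# image size is a power of two and at least minres.
def num_conv_layers(size, minres=4):
    return int(np.log2(size) - np.log2(minres))

print(num_conv_layers(64, minres=4))   # 4 -- equivalent to the old [4, 4, 4, 4]
print(num_conv_layers(128, minres=4))  # 5 -- larger inputs need no config change
```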
networks.py (92 changed lines)
@@ -337,7 +337,8 @@ class MultiEncoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         symlog_inputs,
@@ -359,7 +360,10 @@ class MultiEncoder(nn.Module):
         self.outdim = 0
         if self.cnn_shapes:
             input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
-            self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
+            input_shape = tuple(self.cnn_shapes.values())[0][:2] + (input_ch,)
+            self._cnn = ConvEncoder(
+                input_shape, cnn_depth, act, norm, kernel_size, minres
+            )
             self.outdim += self._cnn.outdim
         if self.mlp_shapes:
             input_size = sum([sum(v) for v in self.mlp_shapes.values()])
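`ConvEncoder` now needs the spatial size, not just the channel count, so `MultiEncoder` hands it the full `(h, w, ch)` shape. A small illustration of that wiring with a hypothetical `cnn_shapes` dict (the `"depth"` key is an invented example):

```python
# All image-like observations are stacked along the channel axis, so the
# shared ConvEncoder sees the spatial size of the first entry plus the
# summed channels.
cnn_shapes = {"image": (64, 64, 3), "depth": (64, 64, 1)}  # hypothetical obs
input_ch = sum([v[-1] for v in cnn_shapes.values()])       # 3 + 1 = 4
input_shape = tuple(cnn_shapes.values())[0][:2] + (input_ch,)
print(input_shape)  # (64, 64, 4)
```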
@@ -396,7 +400,8 @@ class MultiDecoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         cnn_sigmoid,
@@ -426,7 +431,8 @@ class MultiDecoder(nn.Module):
                 cnn_depth,
                 act,
                 norm,
-                cnn_kernels,
+                kernel_size,
+                minres,
                 cnn_sigmoid=cnn_sigmoid,
             )
         if self.mlp_shapes:
@@ -470,35 +476,39 @@ class MultiDecoder(nn.Module):
 
 class ConvEncoder(nn.Module):
     def __init__(
-        self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
+        self,
+        input_shape,
+        depth=32,
+        act="SiLU",
+        norm="LayerNorm",
+        kernel_size=4,
+        minres=4,
     ):
         super(ConvEncoder, self).__init__()
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
-        self._kernels = kernels
-        h, w = 64, 64
+        h, w, input_ch = input_shape
         layers = []
-        for i, kernel in enumerate(self._kernels):
+        for i in range(int(np.log2(h) - np.log2(minres))):
             if i == 0:
-                inp_dim = input_ch
+                in_dim = input_ch
             else:
-                inp_dim = 2 ** (i - 1) * self._depth
-            depth = 2**i * self._depth
+                in_dim = 2 ** (i - 1) * depth
+            out_dim = 2**i * depth
             layers.append(
                 Conv2dSame(
-                    in_channels=inp_dim,
-                    out_channels=depth,
-                    kernel_size=(kernel, kernel),
-                    stride=(2, 2),
+                    in_channels=in_dim,
+                    out_channels=out_dim,
+                    kernel_size=kernel_size,
+                    stride=2,
                     bias=False,
                 )
             )
-            layers.append(ChLayerNorm(depth))
+            layers.append(ChLayerNorm(out_dim))
             layers.append(act())
             h, w = h // 2, w // 2
 
-        self.outdim = depth * h * w
+        self.outdim = out_dim * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)
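As a sanity check on the rewritten loop: channels double per layer (`in_dim = 2**(i-1) * depth`, `out_dim = 2**i * depth`) while `h` and `w` halve, and `outdim` is the flattened final feature map. A standalone sketch mirroring that arithmetic (the function name is mine):

```python
import numpy as np

def encoder_outdim(height, input_ch=3, depth=32, minres=4):
    h, out_dim = height, input_ch
    for i in range(int(np.log2(height) - np.log2(minres))):
        out_dim = 2**i * depth  # 32, 64, 128, 256 for a 64x64 input
        h //= 2                 # each stride-2 layer halves the resolution
    return out_dim * h * h      # flattened minres x minres feature map

print(encoder_outdim(64))   # 4096 = 256 * 4 * 4, same as the old fixed path
print(encoder_outdim(128))  # 8192 = 512 * 4 * 4
```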
@@ -517,52 +527,50 @@ class ConvEncoder(nn.Module):
 class ConvDecoder(nn.Module):
     def __init__(
         self,
-        inp_depth,
+        feat_size,
         shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
         norm=nn.LayerNorm,
-        kernels=(3, 3, 3, 3),
+        kernel_size=4,
+        minres=4,
         outscale=1.0,
         cnn_sigmoid=False,
     ):
         super(ConvDecoder, self).__init__()
-        self._inp_depth = inp_depth
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
         self._shape = shape
-        self._kernels = kernels
         self._cnn_sigmoid = cnn_sigmoid
-        self._embed_size = (
-            (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
-        )
+        layer_num = int(np.log2(shape[1]) - np.log2(minres))
+        self._minres = minres
+        self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
 
-        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
-        inp_dim = self._embed_size // 16
+        self._linear_layer = nn.Linear(feat_size, self._embed_size)
+        in_dim = self._embed_size // (minres**2)
 
         layers = []
-        h, w = 4, 4
-        for i, kernel in enumerate(self._kernels):
-            depth = self._embed_size // 16 // (2 ** (i + 1))
+        h, w = minres, minres
+        for i in range(layer_num):
+            out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
             bias = False
             initializer = tools.weight_init
-            if i == len(self._kernels) - 1:
-                depth = self._shape[0]
+            if i == layer_num - 1:
+                out_dim = self._shape[0]
                 act = False
                 bias = True
                 norm = False
                 initializer = tools.uniform_weight_init(outscale)

             if i != 0:
-                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
-            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
-            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
+                in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
+            pad_h, outpad_h = self.calc_same_pad(k=kernel_size, s=2, d=1)
+            pad_w, outpad_w = self.calc_same_pad(k=kernel_size, s=2, d=1)
             layers.append(
                 nn.ConvTranspose2d(
-                    inp_dim,
-                    depth,
-                    kernel,
+                    in_dim,
+                    out_dim,
+                    kernel_size,
                     2,
                     padding=(pad_h, pad_w),
                     output_padding=(outpad_h, outpad_w),
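The decoder's `_embed_size` is the mirror image of the encoder bottleneck: `minres**2` spatial positions times the deepest channel count `depth * 2**(layer_num - 1)`. Checking the formula for two sizes (numbers derived from the defaults above):

```python
import numpy as np

depth, minres = 32, 4
for size in (64, 128):
    layer_num = int(np.log2(size) - np.log2(minres))
    embed_size = minres**2 * depth * 2 ** (layer_num - 1)
    print(size, layer_num, embed_size)
# 64  -> layer_num=4, embed_size=4096 (matches the encoder outdim above)
# 128 -> layer_num=5, embed_size=8192
```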
@@ -570,7 +578,7 @@ class ConvDecoder(nn.Module):
                 )
             )
             if norm:
-                layers.append(ChLayerNorm(depth))
+                layers.append(ChLayerNorm(out_dim))
             if act:
                 layers.append(act())
             [m.apply(initializer) for m in layers[-3:]]
@@ -587,7 +595,9 @@ class ConvDecoder(nn.Module):
     def forward(self, features, dtype=None):
         x = self._linear_layer(features)
         # (batch, time, -1) -> (batch * time, h, w, ch)
-        x = x.reshape([-1, 4, 4, self._embed_size // 16])
+        x = x.reshape(
+            [-1, self._minres, self._minres, self._embed_size // self._minres**2]
+        )
         # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
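The `forward` reshape now rebuilds the `minres x minres` bottleneck instead of a hard-coded 4x4 grid. A self-contained shape check under the 64x64 defaults (the feature width 1536 is an arbitrary example, not from the diff):

```python
import torch
from torch import nn

minres, depth, layer_num = 4, 32, 4                     # 64x64 defaults
embed_size = minres**2 * depth * 2 ** (layer_num - 1)   # 4096

feat = torch.zeros(8, 10, 1536)                         # (batch, time, feat_size)
x = nn.Linear(1536, embed_size)(feat)                   # (8, 10, 4096)
x = x.reshape([-1, minres, minres, embed_size // minres**2])
x = x.permute(0, 3, 1, 2)                               # channels-first for the convs
print(x.shape)  # torch.Size([80, 256, 4, 4])
```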