expanded the supported image sizes
parent 02c3d45fcf
commit 0faa10ff46
@@ -52,9 +52,9 @@ defaults:
   act: 'SiLU'
   norm: 'LayerNorm'
   encoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
   value_head: 'symlog_disc'
   reward_head: 'symlog_disc'
   dyn_scale: '0.5'
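With cnn_kernels gone, the config controls the CNNs through a single kernel_size plus a target resolution minres, and the number of stride-2 stages is derived from the image size. A minimal sketch (not part of the commit) of that relationship, assuming square images whose side is minres times a power of two:

    import numpy as np

    def num_conv_stages(image_size, minres=4):
        # mirrors int(np.log2(h) - np.log2(minres)) from networks.py
        return int(np.log2(image_size) - np.log2(minres))

    for size in (64, 128, 256):
        print(size, "->", num_conv_stages(size), "stages")  # 4, 5, 6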
networks.py
@@ -337,7 +337,8 @@ class MultiEncoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         symlog_inputs,
@@ -359,7 +360,10 @@ class MultiEncoder(nn.Module):
         self.outdim = 0
         if self.cnn_shapes:
             input_ch = sum([v[-1] for v in self.cnn_shapes.values()])
-            self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels)
+            input_shape = tuple(self.cnn_shapes.values())[0][:2] + (input_ch,)
+            self._cnn = ConvEncoder(
+                input_shape, cnn_depth, act, norm, kernel_size, minres
+            )
             self.outdim += self._cnn.outdim
         if self.mlp_shapes:
             input_size = sum([sum(v) for v in self.mlp_shapes.values()])
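MultiEncoder no longer assumes 64x64 inputs: it builds an (h, w, ch) input_shape from the observation space and hands it to ConvEncoder. A hypothetical usage sketch (the "image" key follows cnn_keys in the config above; the 128x128x3 shape is an assumed example):

    cnn_shapes = {"image": (128, 128, 3)}  # assumed observation shape
    input_ch = sum([v[-1] for v in cnn_shapes.values()])
    input_shape = tuple(cnn_shapes.values())[0][:2] + (input_ch,)
    print(input_shape)  # (128, 128, 3), forwarded to ConvEncoder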
@@ -396,7 +400,8 @@ class MultiDecoder(nn.Module):
         act,
         norm,
         cnn_depth,
-        cnn_kernels,
+        kernel_size,
+        minres,
         mlp_layers,
         mlp_units,
         cnn_sigmoid,
@@ -426,7 +431,8 @@ class MultiDecoder(nn.Module):
             cnn_depth,
             act,
             norm,
-            cnn_kernels,
+            kernel_size,
+            minres,
             cnn_sigmoid=cnn_sigmoid,
         )
         if self.mlp_shapes:
@@ -470,35 +476,39 @@ class MultiDecoder(nn.Module):
 
 class ConvEncoder(nn.Module):
     def __init__(
-        self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3)
+        self,
+        input_shape,
+        depth=32,
+        act="SiLU",
+        norm="LayerNorm",
+        kernel_size=4,
+        minres=4,
     ):
         super(ConvEncoder, self).__init__()
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
-        self._kernels = kernels
-        h, w = 64, 64
+        h, w, input_ch = input_shape
         layers = []
-        for i, kernel in enumerate(self._kernels):
+        for i in range(int(np.log2(h) - np.log2(minres))):
             if i == 0:
-                inp_dim = input_ch
+                in_dim = input_ch
             else:
-                inp_dim = 2 ** (i - 1) * self._depth
-            depth = 2**i * self._depth
+                in_dim = 2 ** (i - 1) * depth
+            out_dim = 2**i * depth
             layers.append(
                 Conv2dSame(
-                    in_channels=inp_dim,
-                    out_channels=depth,
-                    kernel_size=(kernel, kernel),
-                    stride=(2, 2),
+                    in_channels=in_dim,
+                    out_channels=out_dim,
+                    kernel_size=kernel_size,
+                    stride=2,
                     bias=False,
                 )
             )
-            layers.append(ChLayerNorm(depth))
+            layers.append(ChLayerNorm(out_dim))
             layers.append(act())
             h, w = h // 2, w // 2
 
-        self.outdim = depth * h * w
+        self.outdim = out_dim * h * w
         self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)
 
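ConvEncoder now halves the resolution until it reaches minres, doubling the channel count at every stage, so its flattened output size follows directly from the input resolution. A quick arithmetic sketch (using the default depth=32, minres=4; not a copy of the module):

    import numpy as np

    def encoder_outdim(image_size, depth=32, minres=4):
        stages = int(np.log2(image_size) - np.log2(minres))
        out_dim = 2 ** (stages - 1) * depth   # channels after the last stage
        return out_dim * minres * minres      # flattened feature size

    print(encoder_outdim(64))   # 4096
    print(encoder_outdim(128))  # 8192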
@@ -517,52 +527,50 @@ class ConvEncoder(nn.Module):
 class ConvDecoder(nn.Module):
     def __init__(
         self,
-        inp_depth,
+        feat_size,
         shape=(3, 64, 64),
         depth=32,
         act=nn.ELU,
         norm=nn.LayerNorm,
-        kernels=(3, 3, 3, 3),
+        kernel_size=4,
+        minres=4,
         outscale=1.0,
         cnn_sigmoid=False,
     ):
         super(ConvDecoder, self).__init__()
-        self._inp_depth = inp_depth
         act = getattr(torch.nn, act)
         norm = getattr(torch.nn, norm)
-        self._depth = depth
         self._shape = shape
-        self._kernels = kernels
         self._cnn_sigmoid = cnn_sigmoid
-        self._embed_size = (
-            (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1)
-        )
+        layer_num = int(np.log2(shape[1]) - np.log2(minres))
+        self._minres = minres
+        self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
 
-        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
-        inp_dim = self._embed_size // 16
+        self._linear_layer = nn.Linear(feat_size, self._embed_size)
+        in_dim = self._embed_size // (minres**2)
 
         layers = []
-        h, w = 4, 4
-        for i, kernel in enumerate(self._kernels):
-            depth = self._embed_size // 16 // (2 ** (i + 1))
+        h, w = minres, minres
+        for i in range(layer_num):
+            out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
             bias = False
             initializer = tools.weight_init
-            if i == len(self._kernels) - 1:
-                depth = self._shape[0]
+            if i == layer_num - 1:
+                out_dim = self._shape[0]
                 act = False
                 bias = True
                 norm = False
                 initializer = tools.uniform_weight_init(outscale)
 
             if i != 0:
-                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
-            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
-            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
+                in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
+            pad_h, outpad_h = self.calc_same_pad(k=kernel_size, s=2, d=1)
+            pad_w, outpad_w = self.calc_same_pad(k=kernel_size, s=2, d=1)
             layers.append(
                 nn.ConvTranspose2d(
-                    inp_dim,
-                    depth,
-                    kernel,
+                    in_dim,
+                    out_dim,
+                    kernel_size,
                     2,
                     padding=(pad_h, pad_w),
                     output_padding=(outpad_h, outpad_w),
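On the decoder side, the bottleneck size is no longer hard-coded for 64x64 images: layer_num and _embed_size are computed from shape[1] and minres. A worked check (sketch, default depth=32, minres=4) that the new formula reproduces the old 64x64 value and scales with the image size:

    import numpy as np

    def decoder_embed_size(image_size, depth=32, minres=4):
        layer_num = int(np.log2(image_size) - np.log2(minres))
        return minres**2 * depth * 2 ** (layer_num - 1)

    print(decoder_embed_size(64))   # 4096, same as the old (64 // 2**4)**2 * 32 * 2**3
    print(decoder_embed_size(128))  # 8192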
@@ -570,7 +578,7 @@ class ConvDecoder(nn.Module):
                 )
             )
             if norm:
-                layers.append(ChLayerNorm(depth))
+                layers.append(ChLayerNorm(out_dim))
             if act:
                 layers.append(act())
             [m.apply(initializer) for m in layers[-3:]]
@@ -587,7 +595,9 @@ class ConvDecoder(nn.Module):
     def forward(self, features, dtype=None):
         x = self._linear_layer(features)
         # (batch, time, -1) -> (batch * time, h, w, ch)
-        x = x.reshape([-1, 4, 4, self._embed_size // 16])
+        x = x.reshape(
+            [-1, self._minres, self._minres, self._embed_size // self._minres**2]
+        )
         # (batch, time, -1) -> (batch * time, ch, h, w)
         x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
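In forward, the projected features are now reshaped to a minres x minres grid before the transposed convolutions. A shape-flow sketch with assumed batch/time sizes (16 and 64 are illustrative; the embed size matches the 64x64 default above):

    import torch

    minres, embed_size = 4, 4096
    x = torch.zeros(16, 64, embed_size)  # output of self._linear_layer: (batch, time, -1)
    x = x.reshape([-1, minres, minres, embed_size // minres**2])
    print(x.shape)  # torch.Size([1024, 4, 4, 256])
    x = x.permute(0, 3, 1, 2)            # (batch * time, ch, h, w)
    print(x.shape)  # torch.Size([1024, 256, 4, 4])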