Commit 7f66ed5333: "erased unused options"
Author: NM512
Date: 2024-01-05 23:23:09 +09:00
Parent: a27711ab96
6 changed files with 84 additions and 211 deletions

configs.yaml

@@ -17,7 +17,6 @@ defaults:
   compile: True
   precision: 32
   debug: False
-  expl_gifs: False
   video_pred_log: True
   # Environment
@@ -28,27 +27,21 @@ defaults:
   time_limit: 1000
   grayscale: False
   prefill: 2500
-  eval_noise: 0.0
   reward_EMA: True
   # Model
-  dyn_cell: 'gru_layer_norm'
   dyn_hidden: 512
   dyn_deter: 512
   dyn_stoch: 32
   dyn_discrete: 32
-  dyn_input_layers: 1
-  dyn_output_layers: 1
   dyn_rec_depth: 1
-  dyn_shared: False
   dyn_mean_act: 'none'
   dyn_std_act: 'sigmoid2'
   dyn_min_std: 0.1
-  dyn_temp_post: True
   grad_heads: ['decoder', 'reward', 'cont']
   units: 512
   act: 'SiLU'
-  norm: 'LayerNorm'
+  norm: True
   encoder:
     {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
@@ -58,9 +51,9 @@ defaults:
   critic:
     {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
   reward_head:
-    {layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
+    {layers: 2, dist: 'symlog_disc', loss_scale: 1.0, outscale: 0.0}
   cont_head:
-    {layers: 2, scale: 1.0, outscale: 1.0}
+    {layers: 2, loss_scale: 1.0, outscale: 1.0}
   dyn_scale: 0.5
   rep_scale: 0.1
   kl_free: 1.0
@@ -85,12 +78,7 @@ defaults:
   imag_horizon: 15
   imag_gradient: 'dynamics'
   imag_gradient_mix: 0.0
-  imag_sample: True
-  expl_amount: 0
   eval_state_mean: False
-  collect_dyn_sample: True
-  behavior_stop_grad: True
-  future_entropy: False
   # Exploration
   expl_behavior: 'greedy'
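
The two renames above ripple through the code diffs below: `norm` becomes a plain boolean that gates a hard-coded `nn.LayerNorm`, and the head `scale` entries become `loss_scale`. A minimal sketch of how these options are consumed; `build_block` is a hypothetical helper, not a function from the repo:

    import torch
    from torch import nn

    def build_block(inp_dim, units, norm=True, act="SiLU"):
        # norm is now a bool: True appends a LayerNorm, False skips it entirely
        layers = [nn.Linear(inp_dim, units, bias=False)]
        if norm:
            layers.append(nn.LayerNorm(units, eps=1e-03))
        layers.append(getattr(nn, act)())
        return nn.Sequential(*layers)

    # head losses are weighted by the renamed key; unlisted heads default to 1.0
    head_cfg = {"layers": 2, "dist": "symlog_disc", "loss_scale": 1.0, "outscale": 0.0}
    reward_weight = head_cfg["loss_scale"]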

dreamer.py

@@ -42,9 +42,7 @@ class Dreamer(nn.Module):
         self._update_count = 0
         self._dataset = dataset
         self._wm = models.WorldModel(obs_space, act_space, self._step, config)
-        self._task_behavior = models.ImagBehavior(
-            config, self._wm, config.behavior_stop_grad
-        )
+        self._task_behavior = models.ImagBehavior(config, self._wm)
         if (
             config.compile and os.name != "nt"
         ):  # compilation is not supported on windows
@@ -92,9 +90,7 @@ class Dreamer(nn.Module):
             latent, action = state
         obs = self._wm.preprocess(obs)
         embed = self._wm.encoder(obs)
-        latent, _ = self._wm.dynamics.obs_step(
-            latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
-        )
+        latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
         if self._config.eval_state_mean:
             latent["stoch"] = latent["mean"]
         feat = self._wm.dynamics.get_feat(latent)
@@ -114,21 +110,10 @@ class Dreamer(nn.Module):
             action = torch.one_hot(
                 torch.argmax(action, dim=-1), self._config.num_actions
             )
-        action = self._exploration(action, training)
         policy_output = {"action": action, "logprob": logprob}
         state = (latent, action)
         return policy_output, state

-    def _exploration(self, action, training):
-        amount = self._config.expl_amount if training else self._config.eval_noise
-        if amount == 0:
-            return action
-        if "onehot" in self._config.actor["dist"]:
-            probs = amount / self._config.num_actions + (1 - amount) * action
-            return tools.OneHotDist(probs=probs).sample()
-        else:
-            return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
-
     def _train(self, data):
         metrics = {}
         post, context, mets = self._wm._train(data)
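
With `expl_amount` and `eval_noise` removed, the deleted `_exploration` hook no longer layers epsilon-style noise on top of the policy; exploration now comes only from sampling the actor distribution during training. A condensed, paraphrased sketch of the surviving branch structure in `_policy` (names from the diff, branches collapsed, not the verbatim method body):

    # condensed sketch: sample while training, act greedily at eval time
    actor = self._task_behavior.actor(feat)
    action = actor.sample() if training else actor.mode()   # no post-hoc noise
    logprob = actor.log_prob(action)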

exploration.py

@@ -38,7 +38,7 @@ class Random(nn.Module):
 class Plan2Explore(nn.Module):
-    def __init__(self, config, world_model, reward=None):
+    def __init__(self, config, world_model, reward):
         super(Plan2Explore, self).__init__()
         self._config = config
         self._use_amp = True if config.precision == 16 else False
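
With the `=None` default dropped, `reward` is now a required argument, so callers must pass the objective explicitly. A hypothetical wiring sketch (the real call site lives in dreamer.py, outside this diff; the lambda signature mirrors how reward callables are used elsewhere in the codebase):

    # hypothetical: a reward callable over (feat, state, action)
    reward_fn = lambda f, s, a: world_model.heads["reward"](f).mean
    agent = Plan2Explore(config, world_model, reward_fn)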

models.py

@@ -1,8 +1,6 @@
 import copy
 import torch
 from torch import nn
-import numpy as np
-from PIL import ImageColor, Image, ImageDraw, ImageFont

 import networks
 import tools
@@ -10,21 +8,21 @@ import tools
 to_np = lambda x: x.detach().cpu().numpy()

-class RewardEMA(object):
+class RewardEMA:
     """running mean and std"""

     def __init__(self, device, alpha=1e-2):
         self.device = device
-        self.values = torch.zeros((2,)).to(device)
         self.alpha = alpha
         self.range = torch.tensor([0.05, 0.95]).to(device)

-    def __call__(self, x):
+    def __call__(self, x, ema_vals):
         flat_x = torch.flatten(x.detach())
         x_quantile = torch.quantile(input=flat_x, q=self.range)
-        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
-        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
-        offset = self.values[0]
+        # this should be in-place operation
+        ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
+        scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
+        offset = ema_vals[0]
         return offset.detach(), scale.detach()
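
`RewardEMA` is now stateless: the running quantile estimates live in a caller-owned tensor that is updated in place, which lets `ImagBehavior` (later in this file) register that tensor as a module buffer. A self-contained sketch using only what this hunk defines:

    import torch

    ema = RewardEMA(device="cpu")
    ema_vals = torch.zeros(2)                # caller-owned state, later a registered buffer
    returns = torch.randn(16, 15)            # stand-in for a batch of lambda-returns
    offset, scale = ema(returns, ema_vals)   # mutates ema_vals in place
    normed = (returns - offset) / scale      # offset ~ P5 EMA, scale = max(P95 - P5, 1)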
@@ -41,18 +39,13 @@ class WorldModel(nn.Module):
             config.dyn_stoch,
             config.dyn_deter,
             config.dyn_hidden,
-            config.dyn_input_layers,
-            config.dyn_output_layers,
             config.dyn_rec_depth,
-            config.dyn_shared,
             config.dyn_discrete,
             config.act,
             config.norm,
             config.dyn_mean_act,
             config.dyn_std_act,
-            config.dyn_temp_post,
             config.dyn_min_std,
-            config.dyn_cell,
             config.unimix_ratio,
             config.initial,
             config.num_actions,
@@ -106,10 +99,10 @@ class WorldModel(nn.Module):
         print(
             f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
         )
+        # other losses are scaled by 1.0.
         self._scales = dict(
-            reward=config.reward_head["scale"],
-            cont=config.cont_head["scale"],
-            image=1.0,
+            reward=config.reward_head["loss_scale"],
+            cont=config.cont_head["loss_scale"],
         )

     def _train(self, data):
@@ -148,7 +141,8 @@ class WorldModel(nn.Module):
                 assert loss.shape == embed.shape[:2], (name, loss.shape)
                 losses[name] = loss
             scaled = {
-                key: value * self._scales[key] for key, value in losses.items()
+                key: value * self._scales.get(key, 1.0)
+                for key, value in losses.items()
             }
             model_loss = sum(scaled.values()) + kl_loss
         metrics = self._model_opt(torch.mean(model_loss), self.parameters())
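
The dict-comprehension change is what lets `image` (and any future head) drop out of `self._scales`: losses without an explicit entry are weighted by the default of 1.0. A tiny illustration with dummy values:

    scales = {"reward": 1.0, "cont": 1.0}          # image no longer listed
    losses = {"reward": 0.3, "cont": 0.1, "image": 2.5}
    scaled = {k: v * scales.get(k, 1.0) for k, v in losses.items()}
    assert scaled["image"] == 2.5                  # falls back to the 1.0 default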
@@ -217,13 +211,11 @@ class WorldModel(nn.Module):
 class ImagBehavior(nn.Module):
-    def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
+    def __init__(self, config, world_model):
         super(ImagBehavior, self).__init__()
         self._use_amp = True if config.precision == 16 else False
         self._config = config
         self._world_model = world_model
-        self._stop_grad_actor = stop_grad_actor
-        self._reward = reward
         if config.dyn_discrete:
             feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
         else:
@@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
             f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
         )
         if self._config.reward_EMA:
+            # register ema_vals to nn.Module for enabling torch.save and torch.load
+            self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
             self.reward_ema = RewardEMA(device=self._config.device)

     def _train(
         self,
         start,
-        objective=None,
-        action=None,
-        reward=None,
-        imagine=None,
-        tape=None,
-        repeats=None,
+        objective,
     ):
-        objective = objective or self._reward
         self._update_slow_target()
         metrics = {}
         with tools.RequiresGrad(self.actor):
             with torch.cuda.amp.autocast(self._use_amp):
                 imag_feat, imag_state, imag_action = self._imagine(
-                    start, self.actor, self._config.imag_horizon, repeats
+                    start, self.actor, self._config.imag_horizon
                 )
                 reward = objective(imag_feat, imag_state, imag_action)
                 actor_ent = self.actor(imag_feat).entropy()
                 state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
-                # this target is not scaled
-                # slow is flag to indicate whether slow_target is used for lambda-return
+                # this target is not scaled by ema or sym_log.
                 target, weights, base = self._compute_target(
-                    imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
+                    imag_feat, imag_state, reward
                 )
                 actor_loss, mets = self._compute_actor_loss(
                     imag_feat,
-                    imag_state,
                     imag_action,
                     target,
-                    actor_ent,
-                    state_ent,
                     weights,
                     base,
                 )
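
Registering `ema_vals` as a buffer puts the normalizer statistics into `state_dict()`, so checkpoints now carry them alongside the network weights, as the in-diff comment notes. A hedged round-trip sketch (constructor arguments elided):

    behavior = models.ImagBehavior(config, world_model)
    torch.save(behavior.state_dict(), "behavior.pt")      # includes "ema_vals"
    behavior.load_state_dict(torch.load("behavior.pt"))   # EMA statistics restored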
@@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
         metrics.update(self._value_opt(value_loss, self.value.parameters()))
         return imag_feat, imag_state, imag_action, weights, metrics

-    def _imagine(self, start, policy, horizon, repeats=None):
+    def _imagine(self, start, policy, horizon):
         dynamics = self._world_model.dynamics
-        if repeats:
-            raise NotImplemented("repeats is not implemented in this version")
         flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
         start = {k: flatten(v) for k, v in start.items()}

         def step(prev, _):
             state, _, _ = prev
             feat = dynamics.get_feat(state)
-            inp = feat.detach() if self._stop_grad_actor else feat
+            inp = feat.detach()
             action = policy(inp).sample()
-            succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
+            succ = dynamics.img_step(state, action)
             return succ, feat, action

         succ, feats, actions = tools.static_scan(
             step, [torch.arange(horizon)], (start, None, None)
         )
         states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
-        if repeats:
-            raise NotImplemented("repeats is not implemented in this version")

         return feats, states, actions

-    def _compute_target(
-        self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
-    ):
+    def _compute_target(self, imag_feat, imag_state, reward):
         if "cont" in self._world_model.heads:
             inp = self._world_model.dynamics.get_feat(imag_state)
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
@@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
     def _compute_actor_loss(
         self,
         imag_feat,
-        imag_state,
         imag_action,
         target,
-        actor_ent,
-        state_ent,
         weights,
         base,
     ):
         metrics = {}
-        inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
+        inp = imag_feat.detach()
         policy = self.actor(inp)
+        actor_ent = policy.entropy()
         # Q-val for actor is not transformed using symlog
         target = torch.stack(target, dim=1)
         if self._config.reward_EMA:
-            offset, scale = self.reward_ema(target)
+            offset, scale = self.reward_ema(target, self.ema_vals)
             normed_target = (target - offset) / scale
             normed_base = (base - offset) / scale
             adv = normed_target - normed_base
             metrics.update(tools.tensorstats(normed_target, "normed_target"))
-            values = self.reward_ema.values
-            metrics["EMA_005"] = to_np(values[0])
-            metrics["EMA_095"] = to_np(values[1])
+            metrics["EMA_005"] = to_np(self.ema_vals[0])
+            metrics["EMA_095"] = to_np(self.ema_vals[1])
         if self._config.imag_gradient == "dynamics":
             actor_target = adv
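
A quick check of what this normalization does to the advantage: the offset cancels between target and base, so the update direction is unchanged and only the scale (the 5th-95th percentile range, clipped at 1.0) matters:

    import torch

    offset, scale = torch.tensor(0.2), torch.tensor(1.5)
    target, base = torch.tensor(3.2), torch.tensor(1.7)
    adv = (target - offset) / scale - (base - offset) / scale
    assert torch.isclose(adv, (target - base) / scale)    # offset cancels; scale-only effect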

networks.py

@@ -16,18 +16,13 @@ class RSSM(nn.Module):
         stoch=30,
         deter=200,
         hidden=200,
-        layers_input=1,
-        layers_output=1,
         rec_depth=1,
-        shared=False,
         discrete=False,
         act="SiLU",
-        norm="LayerNorm",
+        norm=True,
         mean_act="none",
         std_act="softplus",
-        temp_post=True,
         min_std=0.1,
-        cell="gru",
         unimix_ratio=0.01,
         initial="learned",
         num_actions=None,
@@ -39,16 +34,11 @@ class RSSM(nn.Module):
         self._deter = deter
         self._hidden = hidden
         self._min_std = min_std
-        self._layers_input = layers_input
-        self._layers_output = layers_output
         self._rec_depth = rec_depth
-        self._shared = shared
         self._discrete = discrete
         act = getattr(torch.nn, act)
-        norm = getattr(torch.nn, norm)
         self._mean_act = mean_act
         self._std_act = std_act
-        self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
         self._initial = initial
         self._num_actions = num_actions
@@ -60,47 +50,30 @@ class RSSM(nn.Module):
             inp_dim = self._stoch * self._discrete + num_actions
         else:
             inp_dim = self._stoch + num_actions
-        if self._shared:
-            inp_dim += self._embed
-        for i in range(self._layers_input):
-            inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            inp_layers.append(norm(self._hidden, eps=1e-03))
-            inp_layers.append(act())
-            if i == 0:
-                inp_dim = self._hidden
+        inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+        if norm:
+            inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
+        inp_layers.append(act())
         self._img_in_layers = nn.Sequential(*inp_layers)
         self._img_in_layers.apply(tools.weight_init)
-        if cell == "gru":
-            self._cell = GRUCell(self._hidden, self._deter)
-            self._cell.apply(tools.weight_init)
-        elif cell == "gru_layer_norm":
-            self._cell = GRUCell(self._hidden, self._deter, norm=True)
-            self._cell.apply(tools.weight_init)
-        else:
-            raise NotImplementedError(cell)
+        self._cell = GRUCell(self._hidden, self._deter, norm=norm)
+        self._cell.apply(tools.weight_init)

         img_out_layers = []
         inp_dim = self._deter
-        for i in range(self._layers_output):
-            img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            img_out_layers.append(norm(self._hidden, eps=1e-03))
-            img_out_layers.append(act())
-            if i == 0:
-                inp_dim = self._hidden
+        img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+        if norm:
+            img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
+        img_out_layers.append(act())
         self._img_out_layers = nn.Sequential(*img_out_layers)
         self._img_out_layers.apply(tools.weight_init)

         obs_out_layers = []
-        if self._temp_post:
-            inp_dim = self._deter + self._embed
-        else:
-            inp_dim = self._embed
-        for i in range(self._layers_output):
-            obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
-            obs_out_layers.append(norm(self._hidden, eps=1e-03))
-            obs_out_layers.append(act())
-            if i == 0:
-                inp_dim = self._hidden
+        inp_dim = self._deter + self._embed
+        obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+        if norm:
+            obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
+        obs_out_layers.append(act())
         self._obs_out_layers = nn.Sequential(*obs_out_layers)
         self._obs_out_layers.apply(tools.weight_init)
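
With `layers_input`/`layers_output` gone, each of these stacks is exactly one Linear (+ optional LayerNorm) + activation block. A quick shape check of the resulting posterior input path, under assumed sizes (512 deterministic units, a 4096-dim CNN embedding):

    import torch
    from torch import nn

    deter, embed, hidden = 512, 4096, 512
    obs_out = nn.Sequential(
        nn.Linear(deter + embed, hidden, bias=False),
        nn.LayerNorm(hidden, eps=1e-03),
        nn.SiLU(),
    )
    x = torch.randn(16, deter + embed)       # cat([prior["deter"], embed], -1)
    assert obs_out(x).shape == (16, hidden)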
@@ -200,9 +173,6 @@ class RSSM(nn.Module):
         return dist

     def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
-        # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
-        # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
         # initialize all prev_state
         if prev_state == None or torch.sum(is_first) == len(is_first):
             prev_state = self.initial(len(is_first))
@@ -223,14 +193,8 @@ class RSSM(nn.Module):
                     val * (1.0 - is_first_r) + init_state[key] * is_first_r
                 )

-        prior = self.img_step(prev_state, prev_action, None, sample)
-        if self._shared:
-            post = self.img_step(prev_state, prev_action, embed, sample)
-        else:
-            if self._temp_post:
-                x = torch.cat([prior["deter"], embed], -1)
-            else:
-                x = embed
+        prior = self.img_step(prev_state, prev_action)
+        x = torch.cat([prior["deter"], embed], -1)
         # (batch_size, prior_deter + embed) -> (batch_size, hidden)
         x = self._obs_out_layers(x)
         # (batch_size, hidden) -> (batch_size, stoch, discrete_num)
@@ -242,21 +206,14 @@ class RSSM(nn.Module):
         post = {"stoch": stoch, "deter": prior["deter"], **stats}
         return post, prior

-    # this is used for making future image
-    def img_step(self, prev_state, prev_action, embed=None, sample=True):
+    def img_step(self, prev_state, prev_action, sample=True):
         # (batch, stoch, discrete_num)
         prev_stoch = prev_state["stoch"]
         if self._discrete:
             shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
             # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
             prev_stoch = prev_stoch.reshape(shape)
-        if self._shared:
-            if embed is None:
-                shape = list(prev_action.shape[:-1]) + [self._embed]
-                embed = torch.zeros(shape)
-            # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
-            x = torch.cat([prev_stoch, prev_action, embed], -1)
-        else:
-            x = torch.cat([prev_stoch, prev_action], -1)
+        # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
+        x = torch.cat([prev_stoch, prev_action], -1)
         # (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
         x = self._img_in_layers(x)
@@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
         layers = []
         for i in range(stages):
             layers.append(
-                Conv2dSame(
+                Conv2dSamePad(
                     in_channels=in_dim,
                     out_channels=out_dim,
                     kernel_size=kernel_size,
@@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
                 )
             )
             if norm:
-                layers.append(ChLayerNorm(out_dim))
+                layers.append(ImgChLayerNorm(out_dim))
             layers.append(act())
             in_dim = out_dim
             out_dim *= 2
@@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
                 )
             )
             if norm:
-                layers.append(ChLayerNorm(out_dim))
+                layers.append(ImgChLayerNorm(out_dim))
             if act:
                 layers.append(act())
             in_dim = out_dim
@@ -637,7 +594,7 @@ class MLP(nn.Module):
         layers,
         units,
         act="SiLU",
-        norm="LayerNorm",
+        norm=True,
         dist="normal",
         std=1.0,
         min_std=0.1,
@@ -654,11 +611,9 @@ class MLP(nn.Module):
         self._shape = (shape,) if isinstance(shape, int) else shape
         if self._shape is not None and len(self._shape) == 0:
             self._shape = (1,)
-        self._layers = layers
         act = getattr(torch.nn, act)
-        norm = getattr(torch.nn, norm)
         self._dist = dist
-        self._std = std
+        self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
         self._min_std = min_std
         self._max_std = max_std
         self._absmax = absmax
@@ -668,13 +623,16 @@ class MLP(nn.Module):
         self._device = device

         self.layers = nn.Sequential()
-        for index in range(self._layers):
+        for i in range(layers):
             self.layers.add_module(
-                f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
+                f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
             )
-            self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
-            self.layers.add_module(f"{name}_act{index}", act())
-            if index == 0:
+            if norm:
+                self.layers.add_module(
+                    f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
+                )
+            self.layers.add_module(f"{name}_act{i}", act())
+            if i == 0:
                 inp_dim = units
         self.layers.apply(tools.weight_init)
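
The `_std` change pre-materializes numeric standard deviations as device tensors once, while leaving the string mode flag (e.g. 'learned') untouched. A small sketch of the same dispatch; `init_std` is a hypothetical helper, not a repo function:

    import torch

    def init_std(std, device="cpu"):
        # strings are mode flags; numbers become device tensors up front
        return std if isinstance(std, str) else torch.tensor((std,), device=device)

    assert init_std("learned") == "learned"
    assert init_std(1.0).shape == (1,)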
@@ -783,16 +741,18 @@ class MLP(nn.Module):
 class GRUCell(nn.Module):
-    def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
+    def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
         super(GRUCell, self).__init__()
         self._inp_size = inp_size
         self._size = size
         self._act = act
-        self._norm = norm
         self._update_bias = update_bias
-        self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
+        self.layers = nn.Sequential()
+        self.layers.add_module(
+            "GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
+        )
         if norm:
-            self._norm = nn.LayerNorm(3 * size, eps=1e-03)
+            self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))

     @property
     def state_size(self):
@@ -800,9 +760,7 @@ class GRUCell(nn.Module):
     def forward(self, inputs, state):
         state = state[0]  # Keras wraps the state in a list.
-        parts = self._layer(torch.cat([inputs, state], -1))
-        if self._norm:
-            parts = self._norm(parts)
+        parts = self.layers(torch.cat([inputs, state], -1))
         reset, cand, update = torch.split(parts, [self._size] * 3, -1)
         reset = torch.sigmoid(reset)
         cand = self._act(reset * cand)
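
Folding the linear layer and the optional LayerNorm into one `nn.Sequential` removes the branch from `forward`, and `norm=True` is the new default, matching the old `'gru_layer_norm'` cell that the config used to select. A quick shape-check sketch of the cell as defined here, with assumed sizes:

    import torch

    cell = GRUCell(inp_size=512, size=512)         # LayerNorm applied by default now
    h = torch.zeros(16, 512)
    out, state = cell(torch.randn(16, 512), [h])   # state is Keras-style, list-wrapped
    assert out.shape == (16, 512) and state[0] is out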
@@ -811,7 +769,7 @@ class GRUCell(nn.Module):
         return output, [output]

-class Conv2dSame(torch.nn.Conv2d):
+class Conv2dSamePad(torch.nn.Conv2d):
     def calc_same_pad(self, i, k, s, d):
         return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

@@ -841,9 +799,9 @@ class Conv2dSamePad(torch.nn.Conv2d):
         return ret

-class ChLayerNorm(nn.Module):
+class ImgChLayerNorm(nn.Module):
     def __init__(self, ch, eps=1e-03):
-        super(ChLayerNorm, self).__init__()
+        super(ImgChLayerNorm, self).__init__()
         self.norm = torch.nn.LayerNorm(ch, eps=eps)

     def forward(self, x):

tools.py

@@ -840,37 +840,6 @@ def static_scan(fn, inputs, start):
     return outputs

-# Original version
-# def static_scan2(fn, inputs, start, reverse=False):
-#     last = start
-#     outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
-#     indices = range(inputs[0].shape[0])
-#     if reverse:
-#         indices = reversed(indices)
-#     for index in indices:
-#         inp = lambda x: (_input[x] for _input in inputs)
-#         last = fn(last, *inp(index))
-#         [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
-#     if reverse:
-#         outputs = [list(reversed(x)) for x in outputs]
-#     res = [[]] * len(outputs)
-#     for i in range(len(outputs)):
-#         if type(outputs[i][0]) == type({}):
-#             _res = {}
-#             for key in outputs[i][0].keys():
-#                 _res[key] = []
-#                 for j in range(len(outputs[i])):
-#                     _res[key].append(outputs[i][j][key])
-#                 #_res[key] = torch.stack(_res[key], 0)
-#                 _res[key] = faster_stack(_res[key], 0)
-#         else:
-#             _res = outputs[i]
-#             #_res = torch.stack(_res, 0)
-#             _res = faster_stack(_res, 0)
-#         res[i] = _res
-#     return res

 class Every:
     def __init__(self, every):
         self._every = every