Remove unused options
This commit is contained in:
parent
a27711ab96
commit
7f66ed5333
18
configs.yaml
18
configs.yaml
@ -17,7 +17,6 @@ defaults:
|
||||
compile: True
|
||||
precision: 32
|
||||
debug: False
|
||||
expl_gifs: False
|
||||
video_pred_log: True
|
||||
|
||||
# Environment
|
||||
@ -28,27 +27,21 @@ defaults:
|
||||
time_limit: 1000
|
||||
grayscale: False
|
||||
prefill: 2500
|
||||
eval_noise: 0.0
|
||||
reward_EMA: True
|
||||
|
||||
# Model
|
||||
dyn_cell: 'gru_layer_norm'
|
||||
dyn_hidden: 512
|
||||
dyn_deter: 512
|
||||
dyn_stoch: 32
|
||||
dyn_discrete: 32
|
||||
dyn_input_layers: 1
|
||||
dyn_output_layers: 1
|
||||
dyn_rec_depth: 1
|
||||
dyn_shared: False
|
||||
dyn_mean_act: 'none'
|
||||
dyn_std_act: 'sigmoid2'
|
||||
dyn_min_std: 0.1
|
||||
dyn_temp_post: True
|
||||
grad_heads: ['decoder', 'reward', 'cont']
|
||||
units: 512
|
||||
act: 'SiLU'
|
||||
norm: 'LayerNorm'
|
||||
norm: True
|
||||
encoder:
|
||||
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
|
||||
decoder:
|
||||
@ -58,9 +51,9 @@ defaults:
|
||||
critic:
|
||||
{layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
|
||||
reward_head:
|
||||
{layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
|
||||
{layers: 2, dist: 'symlog_disc', loss_scale: 1.0, outscale: 0.0}
|
||||
cont_head:
|
||||
{layers: 2, scale: 1.0, outscale: 1.0}
|
||||
{layers: 2, loss_scale: 1.0, outscale: 1.0}
|
||||
dyn_scale: 0.5
|
||||
rep_scale: 0.1
|
||||
kl_free: 1.0
|
||||
@ -85,12 +78,7 @@ defaults:
|
||||
imag_horizon: 15
|
||||
imag_gradient: 'dynamics'
|
||||
imag_gradient_mix: 0.0
|
||||
imag_sample: True
|
||||
expl_amount: 0
|
||||
eval_state_mean: False
|
||||
collect_dyn_sample: True
|
||||
behavior_stop_grad: True
|
||||
future_entropy: False
|
||||
|
||||
# Exploration
|
||||
expl_behavior: 'greedy'
|
||||
|
19
dreamer.py
19
dreamer.py
@ -42,9 +42,7 @@ class Dreamer(nn.Module):
|
||||
self._update_count = 0
|
||||
self._dataset = dataset
|
||||
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
self._task_behavior = models.ImagBehavior(config, self._wm)
|
||||
if (
|
||||
config.compile and os.name != "nt"
|
||||
): # compilation is not supported on windows
|
||||
@ -92,9 +90,7 @@ class Dreamer(nn.Module):
|
||||
latent, action = state
|
||||
obs = self._wm.preprocess(obs)
|
||||
embed = self._wm.encoder(obs)
|
||||
latent, _ = self._wm.dynamics.obs_step(
|
||||
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
|
||||
)
|
||||
latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
|
||||
if self._config.eval_state_mean:
|
||||
latent["stoch"] = latent["mean"]
|
||||
feat = self._wm.dynamics.get_feat(latent)
|
||||
@ -114,21 +110,10 @@ class Dreamer(nn.Module):
|
||||
action = torch.one_hot(
|
||||
torch.argmax(action, dim=-1), self._config.num_actions
|
||||
)
|
||||
action = self._exploration(action, training)
|
||||
policy_output = {"action": action, "logprob": logprob}
|
||||
state = (latent, action)
|
||||
return policy_output, state
|
||||
|
||||
def _exploration(self, action, training):
|
||||
amount = self._config.expl_amount if training else self._config.eval_noise
|
||||
if amount == 0:
|
||||
return action
|
||||
if "onehot" in self._config.actor["dist"]:
|
||||
probs = amount / self._config.num_actions + (1 - amount) * action
|
||||
return tools.OneHotDist(probs=probs).sample()
|
||||
else:
|
||||
return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
|
||||
|
||||
def _train(self, data):
|
||||
metrics = {}
|
||||
post, context, mets = self._wm._train(data)
|
||||
|
@ -38,7 +38,7 @@ class Random(nn.Module):
|
||||
|
||||
|
||||
class Plan2Explore(nn.Module):
|
||||
def __init__(self, config, world_model, reward=None):
|
||||
def __init__(self, config, world_model, reward):
|
||||
super(Plan2Explore, self).__init__()
|
||||
self._config = config
|
||||
self._use_amp = True if config.precision == 16 else False
|
||||
|
79
models.py
79
models.py
@ -1,8 +1,6 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from PIL import ImageColor, Image, ImageDraw, ImageFont
|
||||
|
||||
import networks
|
||||
import tools
|
||||
@ -10,21 +8,21 @@ import tools
|
||||
to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
|
||||
class RewardEMA(object):
|
||||
class RewardEMA:
|
||||
"""running mean and std"""
|
||||
|
||||
def __init__(self, device, alpha=1e-2):
|
||||
self.device = device
|
||||
self.values = torch.zeros((2,)).to(device)
|
||||
self.alpha = alpha
|
||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, x, ema_vals):
|
||||
flat_x = torch.flatten(x.detach())
|
||||
x_quantile = torch.quantile(input=flat_x, q=self.range)
|
||||
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
|
||||
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
|
||||
offset = self.values[0]
|
||||
# this should be in-place operation
|
||||
ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
|
||||
scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
|
||||
offset = ema_vals[0]
|
||||
return offset.detach(), scale.detach()
|
||||
|
||||
|
||||
@ -41,18 +39,13 @@ class WorldModel(nn.Module):
|
||||
config.dyn_stoch,
|
||||
config.dyn_deter,
|
||||
config.dyn_hidden,
|
||||
config.dyn_input_layers,
|
||||
config.dyn_output_layers,
|
||||
config.dyn_rec_depth,
|
||||
config.dyn_shared,
|
||||
config.dyn_discrete,
|
||||
config.act,
|
||||
config.norm,
|
||||
config.dyn_mean_act,
|
||||
config.dyn_std_act,
|
||||
config.dyn_temp_post,
|
||||
config.dyn_min_std,
|
||||
config.dyn_cell,
|
||||
config.unimix_ratio,
|
||||
config.initial,
|
||||
config.num_actions,
|
||||
@ -106,10 +99,10 @@ class WorldModel(nn.Module):
|
||||
print(
|
||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||
)
|
||||
# other losses are scaled by 1.0.
|
||||
self._scales = dict(
|
||||
reward=config.reward_head["scale"],
|
||||
cont=config.cont_head["scale"],
|
||||
image=1.0,
|
||||
reward=config.reward_head["loss_scale"],
|
||||
cont=config.cont_head["loss_scale"],
|
||||
)
|
||||
|
||||
def _train(self, data):
|
||||
@ -148,7 +141,8 @@ class WorldModel(nn.Module):
|
||||
assert loss.shape == embed.shape[:2], (name, loss.shape)
|
||||
losses[name] = loss
|
||||
scaled = {
|
||||
key: value * self._scales[key] for key, value in losses.items()
|
||||
key: value * self._scales.get(key, 1.0)
|
||||
for key, value in losses.items()
|
||||
}
|
||||
model_loss = sum(scaled.values()) + kl_loss
|
||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||
@ -217,13 +211,11 @@ class WorldModel(nn.Module):
|
||||
|
||||
|
||||
class ImagBehavior(nn.Module):
|
||||
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
|
||||
def __init__(self, config, world_model):
|
||||
super(ImagBehavior, self).__init__()
|
||||
self._use_amp = True if config.precision == 16 else False
|
||||
self._config = config
|
||||
self._world_model = world_model
|
||||
self._stop_grad_actor = stop_grad_actor
|
||||
self._reward = reward
|
||||
if config.dyn_discrete:
|
||||
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||
else:
|
||||
@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
|
||||
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
||||
)
|
||||
if self._config.reward_EMA:
|
||||
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
||||
self.reward_ema = RewardEMA(device=self._config.device)
|
||||
|
||||
def _train(
|
||||
self,
|
||||
start,
|
||||
objective=None,
|
||||
action=None,
|
||||
reward=None,
|
||||
imagine=None,
|
||||
tape=None,
|
||||
repeats=None,
|
||||
objective,
|
||||
):
|
||||
objective = objective or self._reward
|
||||
self._update_slow_target()
|
||||
metrics = {}
|
||||
|
||||
with tools.RequiresGrad(self.actor):
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
imag_feat, imag_state, imag_action = self._imagine(
|
||||
start, self.actor, self._config.imag_horizon, repeats
|
||||
start, self.actor, self._config.imag_horizon
|
||||
)
|
||||
reward = objective(imag_feat, imag_state, imag_action)
|
||||
actor_ent = self.actor(imag_feat).entropy()
|
||||
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
|
||||
# this target is not scaled
|
||||
# slow is flag to indicate whether slow_target is used for lambda-return
|
||||
# this target is not scaled by ema or sym_log.
|
||||
target, weights, base = self._compute_target(
|
||||
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
||||
imag_feat, imag_state, reward
|
||||
)
|
||||
actor_loss, mets = self._compute_actor_loss(
|
||||
imag_feat,
|
||||
imag_state,
|
||||
imag_action,
|
||||
target,
|
||||
actor_ent,
|
||||
state_ent,
|
||||
weights,
|
||||
base,
|
||||
)
|
||||
@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
|
||||
metrics.update(self._value_opt(value_loss, self.value.parameters()))
|
||||
return imag_feat, imag_state, imag_action, weights, metrics
|
||||
|
||||
def _imagine(self, start, policy, horizon, repeats=None):
|
||||
def _imagine(self, start, policy, horizon):
|
||||
dynamics = self._world_model.dynamics
|
||||
if repeats:
|
||||
raise NotImplemented("repeats is not implemented in this version")
|
||||
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
|
||||
start = {k: flatten(v) for k, v in start.items()}
|
||||
|
||||
def step(prev, _):
|
||||
state, _, _ = prev
|
||||
feat = dynamics.get_feat(state)
|
||||
inp = feat.detach() if self._stop_grad_actor else feat
|
||||
inp = feat.detach()
|
||||
action = policy(inp).sample()
|
||||
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
|
||||
succ = dynamics.img_step(state, action)
|
||||
return succ, feat, action
|
||||
|
||||
succ, feats, actions = tools.static_scan(
|
||||
step, [torch.arange(horizon)], (start, None, None)
|
||||
)
|
||||
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
|
||||
if repeats:
|
||||
raise NotImplemented("repeats is not implemented in this version")
|
||||
|
||||
return feats, states, actions
|
||||
|
||||
def _compute_target(
|
||||
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
||||
):
|
||||
def _compute_target(self, imag_feat, imag_state, reward):
|
||||
if "cont" in self._world_model.heads:
|
||||
inp = self._world_model.dynamics.get_feat(imag_state)
|
||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||
@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
|
||||
def _compute_actor_loss(
|
||||
self,
|
||||
imag_feat,
|
||||
imag_state,
|
||||
imag_action,
|
||||
target,
|
||||
actor_ent,
|
||||
state_ent,
|
||||
weights,
|
||||
base,
|
||||
):
|
||||
metrics = {}
|
||||
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
|
||||
inp = imag_feat.detach()
|
||||
policy = self.actor(inp)
|
||||
actor_ent = policy.entropy()
|
||||
# Q-val for actor is not transformed using symlog
|
||||
target = torch.stack(target, dim=1)
|
||||
if self._config.reward_EMA:
|
||||
offset, scale = self.reward_ema(target)
|
||||
offset, scale = self.reward_ema(target, self.ema_vals)
|
||||
normed_target = (target - offset) / scale
|
||||
normed_base = (base - offset) / scale
|
||||
adv = normed_target - normed_base
|
||||
metrics.update(tools.tensorstats(normed_target, "normed_target"))
|
||||
values = self.reward_ema.values
|
||||
metrics["EMA_005"] = to_np(values[0])
|
||||
metrics["EMA_095"] = to_np(values[1])
|
||||
metrics["EMA_005"] = to_np(self.ema_vals[0])
|
||||
metrics["EMA_095"] = to_np(self.ema_vals[1])
|
||||
|
||||
if self._config.imag_gradient == "dynamics":
|
||||
actor_target = adv
|
||||
|
146
networks.py
146
networks.py
@ -16,18 +16,13 @@ class RSSM(nn.Module):
|
||||
stoch=30,
|
||||
deter=200,
|
||||
hidden=200,
|
||||
layers_input=1,
|
||||
layers_output=1,
|
||||
rec_depth=1,
|
||||
shared=False,
|
||||
discrete=False,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
mean_act="none",
|
||||
std_act="softplus",
|
||||
temp_post=True,
|
||||
min_std=0.1,
|
||||
cell="gru",
|
||||
unimix_ratio=0.01,
|
||||
initial="learned",
|
||||
num_actions=None,
|
||||
@ -39,16 +34,11 @@ class RSSM(nn.Module):
|
||||
self._deter = deter
|
||||
self._hidden = hidden
|
||||
self._min_std = min_std
|
||||
self._layers_input = layers_input
|
||||
self._layers_output = layers_output
|
||||
self._rec_depth = rec_depth
|
||||
self._shared = shared
|
||||
self._discrete = discrete
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._mean_act = mean_act
|
||||
self._std_act = std_act
|
||||
self._temp_post = temp_post
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._initial = initial
|
||||
self._num_actions = num_actions
|
||||
@ -60,47 +50,30 @@ class RSSM(nn.Module):
|
||||
inp_dim = self._stoch * self._discrete + num_actions
|
||||
else:
|
||||
inp_dim = self._stoch + num_actions
|
||||
if self._shared:
|
||||
inp_dim += self._embed
|
||||
for i in range(self._layers_input):
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
inp_layers.append(norm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(act())
|
||||
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||
self._img_in_layers.apply(tools.weight_init)
|
||||
if cell == "gru":
|
||||
self._cell = GRUCell(self._hidden, self._deter)
|
||||
self._cell.apply(tools.weight_init)
|
||||
elif cell == "gru_layer_norm":
|
||||
self._cell = GRUCell(self._hidden, self._deter, norm=True)
|
||||
self._cell.apply(tools.weight_init)
|
||||
else:
|
||||
raise NotImplementedError(cell)
|
||||
self._cell = GRUCell(self._hidden, self._deter, norm=norm)
|
||||
self._cell.apply(tools.weight_init)
|
||||
|
||||
img_out_layers = []
|
||||
inp_dim = self._deter
|
||||
for i in range(self._layers_output):
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
img_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(act())
|
||||
self._img_out_layers = nn.Sequential(*img_out_layers)
|
||||
self._img_out_layers.apply(tools.weight_init)
|
||||
|
||||
obs_out_layers = []
|
||||
if self._temp_post:
|
||||
inp_dim = self._deter + self._embed
|
||||
else:
|
||||
inp_dim = self._embed
|
||||
for i in range(self._layers_output):
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
obs_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
inp_dim = self._deter + self._embed
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(act())
|
||||
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
||||
self._obs_out_layers.apply(tools.weight_init)
|
||||
|
||||
@ -200,9 +173,6 @@ class RSSM(nn.Module):
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
prev_state = self.initial(len(is_first))
|
||||
@ -223,41 +193,28 @@ class RSSM(nn.Module):
|
||||
val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
)
|
||||
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
prior = self.img_step(prev_state, prev_action)
|
||||
x = torch.cat([prior["deter"], embed], -1)
|
||||
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
||||
x = self._obs_out_layers(x)
|
||||
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
||||
stats = self._suff_stats_layer("obs", x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
if self._temp_post:
|
||||
x = torch.cat([prior["deter"], embed], -1)
|
||||
else:
|
||||
x = embed
|
||||
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
||||
x = self._obs_out_layers(x)
|
||||
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
||||
stats = self._suff_stats_layer("obs", x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
stoch = self.get_dist(stats).mode()
|
||||
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
||||
stoch = self.get_dist(stats).mode()
|
||||
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
||||
return post, prior
|
||||
|
||||
# this is used for making future image
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
def img_step(self, prev_state, prev_action, sample=True):
|
||||
# (batch, stoch, discrete_num)
|
||||
prev_stoch = prev_state["stoch"]
|
||||
if self._discrete:
|
||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
|
||||
prev_stoch = prev_stoch.reshape(shape)
|
||||
if self._shared:
|
||||
if embed is None:
|
||||
shape = list(prev_action.shape[:-1]) + [self._embed]
|
||||
embed = torch.zeros(shape)
|
||||
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
|
||||
x = torch.cat([prev_stoch, prev_action, embed], -1)
|
||||
else:
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||
x = self._img_in_layers(x)
|
||||
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
||||
@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
|
||||
layers = []
|
||||
for i in range(stages):
|
||||
layers.append(
|
||||
Conv2dSame(
|
||||
Conv2dSamePad(
|
||||
in_channels=in_dim,
|
||||
out_channels=out_dim,
|
||||
kernel_size=kernel_size,
|
||||
@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
|
||||
)
|
||||
)
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(ImgChLayerNorm(out_dim))
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
out_dim *= 2
|
||||
@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
|
||||
)
|
||||
)
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(ImgChLayerNorm(out_dim))
|
||||
if act:
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
@ -637,7 +594,7 @@ class MLP(nn.Module):
|
||||
layers,
|
||||
units,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
min_std=0.1,
|
||||
@ -654,11 +611,9 @@ class MLP(nn.Module):
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
if self._shape is not None and len(self._shape) == 0:
|
||||
self._shape = (1,)
|
||||
self._layers = layers
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._absmax = absmax
|
||||
@ -668,13 +623,16 @@ class MLP(nn.Module):
|
||||
self._device = device
|
||||
|
||||
self.layers = nn.Sequential()
|
||||
for index in range(self._layers):
|
||||
for i in range(layers):
|
||||
self.layers.add_module(
|
||||
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
|
||||
f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
|
||||
)
|
||||
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
|
||||
self.layers.add_module(f"{name}_act{index}", act())
|
||||
if index == 0:
|
||||
if norm:
|
||||
self.layers.add_module(
|
||||
f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
|
||||
)
|
||||
self.layers.add_module(f"{name}_act{i}", act())
|
||||
if i == 0:
|
||||
inp_dim = units
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
@ -783,16 +741,18 @@ class MLP(nn.Module):
|
||||
|
||||
|
||||
class GRUCell(nn.Module):
|
||||
def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
|
||||
def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
|
||||
super(GRUCell, self).__init__()
|
||||
self._inp_size = inp_size
|
||||
self._size = size
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
self._update_bias = update_bias
|
||||
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||
self.layers = nn.Sequential()
|
||||
self.layers.add_module(
|
||||
"GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||
)
|
||||
if norm:
|
||||
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
|
||||
self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@ -800,9 +760,7 @@ class GRUCell(nn.Module):
|
||||
|
||||
def forward(self, inputs, state):
|
||||
state = state[0] # Keras wraps the state in a list.
|
||||
parts = self._layer(torch.cat([inputs, state], -1))
|
||||
if self._norm:
|
||||
parts = self._norm(parts)
|
||||
parts = self.layers(torch.cat([inputs, state], -1))
|
||||
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
|
||||
reset = torch.sigmoid(reset)
|
||||
cand = self._act(reset * cand)
|
||||
@ -811,7 +769,7 @@ class GRUCell(nn.Module):
|
||||
return output, [output]
|
||||
|
||||
|
||||
class Conv2dSame(torch.nn.Conv2d):
|
||||
class Conv2dSamePad(torch.nn.Conv2d):
|
||||
def calc_same_pad(self, i, k, s, d):
|
||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d):
|
||||
return ret
|
||||
|
||||
|
||||
class ChLayerNorm(nn.Module):
|
||||
class ImgChLayerNorm(nn.Module):
|
||||
def __init__(self, ch, eps=1e-03):
|
||||
super(ChLayerNorm, self).__init__()
|
||||
super(ImgChLayerNorm, self).__init__()
|
||||
self.norm = torch.nn.LayerNorm(ch, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
|
31
tools.py
31
tools.py
@ -840,37 +840,6 @@ def static_scan(fn, inputs, start):
|
||||
return outputs
|
||||
|
||||
|
||||
# Original version
|
||||
# def static_scan2(fn, inputs, start, reverse=False):
|
||||
# last = start
|
||||
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
|
||||
# indices = range(inputs[0].shape[0])
|
||||
# if reverse:
|
||||
# indices = reversed(indices)
|
||||
# for index in indices:
|
||||
# inp = lambda x: (_input[x] for _input in inputs)
|
||||
# last = fn(last, *inp(index))
|
||||
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
|
||||
# if reverse:
|
||||
# outputs = [list(reversed(x)) for x in outputs]
|
||||
# res = [[]] * len(outputs)
|
||||
# for i in range(len(outputs)):
|
||||
# if type(outputs[i][0]) == type({}):
|
||||
# _res = {}
|
||||
# for key in outputs[i][0].keys():
|
||||
# _res[key] = []
|
||||
# for j in range(len(outputs[i])):
|
||||
# _res[key].append(outputs[i][j][key])
|
||||
# #_res[key] = torch.stack(_res[key], 0)
|
||||
# _res[key] = faster_stack(_res[key], 0)
|
||||
# else:
|
||||
# _res = outputs[i]
|
||||
# #_res = torch.stack(_res, 0)
|
||||
# _res = faster_stack(_res, 0)
|
||||
# res[i] = _res
|
||||
# return res
|
||||
|
||||
|
||||
class Every:
|
||||
def __init__(self, every):
|
||||
self._every = every
|
||||
|
Loading…
x
Reference in New Issue
Block a user