erased unused options
This commit is contained in:
parent
a27711ab96
commit
7f66ed5333
18
configs.yaml
18
configs.yaml
@ -17,7 +17,6 @@ defaults:
|
|||||||
compile: True
|
compile: True
|
||||||
precision: 32
|
precision: 32
|
||||||
debug: False
|
debug: False
|
||||||
expl_gifs: False
|
|
||||||
video_pred_log: True
|
video_pred_log: True
|
||||||
|
|
||||||
# Environment
|
# Environment
|
||||||
@ -28,27 +27,21 @@ defaults:
|
|||||||
time_limit: 1000
|
time_limit: 1000
|
||||||
grayscale: False
|
grayscale: False
|
||||||
prefill: 2500
|
prefill: 2500
|
||||||
eval_noise: 0.0
|
|
||||||
reward_EMA: True
|
reward_EMA: True
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
dyn_cell: 'gru_layer_norm'
|
|
||||||
dyn_hidden: 512
|
dyn_hidden: 512
|
||||||
dyn_deter: 512
|
dyn_deter: 512
|
||||||
dyn_stoch: 32
|
dyn_stoch: 32
|
||||||
dyn_discrete: 32
|
dyn_discrete: 32
|
||||||
dyn_input_layers: 1
|
|
||||||
dyn_output_layers: 1
|
|
||||||
dyn_rec_depth: 1
|
dyn_rec_depth: 1
|
||||||
dyn_shared: False
|
|
||||||
dyn_mean_act: 'none'
|
dyn_mean_act: 'none'
|
||||||
dyn_std_act: 'sigmoid2'
|
dyn_std_act: 'sigmoid2'
|
||||||
dyn_min_std: 0.1
|
dyn_min_std: 0.1
|
||||||
dyn_temp_post: True
|
|
||||||
grad_heads: ['decoder', 'reward', 'cont']
|
grad_heads: ['decoder', 'reward', 'cont']
|
||||||
units: 512
|
units: 512
|
||||||
act: 'SiLU'
|
act: 'SiLU'
|
||||||
norm: 'LayerNorm'
|
norm: True
|
||||||
encoder:
|
encoder:
|
||||||
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
|
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
|
||||||
decoder:
|
decoder:
|
||||||
@ -58,9 +51,9 @@ defaults:
|
|||||||
critic:
|
critic:
|
||||||
{layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
|
{layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
|
||||||
reward_head:
|
reward_head:
|
||||||
{layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
|
{layers: 2, dist: 'symlog_disc', loss_scale: 1.0, outscale: 0.0}
|
||||||
cont_head:
|
cont_head:
|
||||||
{layers: 2, scale: 1.0, outscale: 1.0}
|
{layers: 2, loss_scale: 1.0, outscale: 1.0}
|
||||||
dyn_scale: 0.5
|
dyn_scale: 0.5
|
||||||
rep_scale: 0.1
|
rep_scale: 0.1
|
||||||
kl_free: 1.0
|
kl_free: 1.0
|
||||||
@ -85,12 +78,7 @@ defaults:
|
|||||||
imag_horizon: 15
|
imag_horizon: 15
|
||||||
imag_gradient: 'dynamics'
|
imag_gradient: 'dynamics'
|
||||||
imag_gradient_mix: 0.0
|
imag_gradient_mix: 0.0
|
||||||
imag_sample: True
|
|
||||||
expl_amount: 0
|
|
||||||
eval_state_mean: False
|
eval_state_mean: False
|
||||||
collect_dyn_sample: True
|
|
||||||
behavior_stop_grad: True
|
|
||||||
future_entropy: False
|
|
||||||
|
|
||||||
# Exploration
|
# Exploration
|
||||||
expl_behavior: 'greedy'
|
expl_behavior: 'greedy'
|
||||||
|
19
dreamer.py
19
dreamer.py
@ -42,9 +42,7 @@ class Dreamer(nn.Module):
|
|||||||
self._update_count = 0
|
self._update_count = 0
|
||||||
self._dataset = dataset
|
self._dataset = dataset
|
||||||
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
|
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
|
||||||
self._task_behavior = models.ImagBehavior(
|
self._task_behavior = models.ImagBehavior(config, self._wm)
|
||||||
config, self._wm, config.behavior_stop_grad
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
config.compile and os.name != "nt"
|
config.compile and os.name != "nt"
|
||||||
): # compilation is not supported on windows
|
): # compilation is not supported on windows
|
||||||
@ -92,9 +90,7 @@ class Dreamer(nn.Module):
|
|||||||
latent, action = state
|
latent, action = state
|
||||||
obs = self._wm.preprocess(obs)
|
obs = self._wm.preprocess(obs)
|
||||||
embed = self._wm.encoder(obs)
|
embed = self._wm.encoder(obs)
|
||||||
latent, _ = self._wm.dynamics.obs_step(
|
latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
|
||||||
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
|
|
||||||
)
|
|
||||||
if self._config.eval_state_mean:
|
if self._config.eval_state_mean:
|
||||||
latent["stoch"] = latent["mean"]
|
latent["stoch"] = latent["mean"]
|
||||||
feat = self._wm.dynamics.get_feat(latent)
|
feat = self._wm.dynamics.get_feat(latent)
|
||||||
@ -114,21 +110,10 @@ class Dreamer(nn.Module):
|
|||||||
action = torch.one_hot(
|
action = torch.one_hot(
|
||||||
torch.argmax(action, dim=-1), self._config.num_actions
|
torch.argmax(action, dim=-1), self._config.num_actions
|
||||||
)
|
)
|
||||||
action = self._exploration(action, training)
|
|
||||||
policy_output = {"action": action, "logprob": logprob}
|
policy_output = {"action": action, "logprob": logprob}
|
||||||
state = (latent, action)
|
state = (latent, action)
|
||||||
return policy_output, state
|
return policy_output, state
|
||||||
|
|
||||||
def _exploration(self, action, training):
|
|
||||||
amount = self._config.expl_amount if training else self._config.eval_noise
|
|
||||||
if amount == 0:
|
|
||||||
return action
|
|
||||||
if "onehot" in self._config.actor["dist"]:
|
|
||||||
probs = amount / self._config.num_actions + (1 - amount) * action
|
|
||||||
return tools.OneHotDist(probs=probs).sample()
|
|
||||||
else:
|
|
||||||
return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
|
|
||||||
|
|
||||||
def _train(self, data):
|
def _train(self, data):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
post, context, mets = self._wm._train(data)
|
post, context, mets = self._wm._train(data)
|
||||||
|
@ -38,7 +38,7 @@ class Random(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Plan2Explore(nn.Module):
|
class Plan2Explore(nn.Module):
|
||||||
def __init__(self, config, world_model, reward=None):
|
def __init__(self, config, world_model, reward):
|
||||||
super(Plan2Explore, self).__init__()
|
super(Plan2Explore, self).__init__()
|
||||||
self._config = config
|
self._config = config
|
||||||
self._use_amp = True if config.precision == 16 else False
|
self._use_amp = True if config.precision == 16 else False
|
||||||
|
79
models.py
79
models.py
@ -1,8 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
|
||||||
from PIL import ImageColor, Image, ImageDraw, ImageFont
|
|
||||||
|
|
||||||
import networks
|
import networks
|
||||||
import tools
|
import tools
|
||||||
@ -10,21 +8,21 @@ import tools
|
|||||||
to_np = lambda x: x.detach().cpu().numpy()
|
to_np = lambda x: x.detach().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
class RewardEMA(object):
|
class RewardEMA:
|
||||||
"""running mean and std"""
|
"""running mean and std"""
|
||||||
|
|
||||||
def __init__(self, device, alpha=1e-2):
|
def __init__(self, device, alpha=1e-2):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.values = torch.zeros((2,)).to(device)
|
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.range = torch.tensor([0.05, 0.95]).to(device)
|
self.range = torch.tensor([0.05, 0.95]).to(device)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x, ema_vals):
|
||||||
flat_x = torch.flatten(x.detach())
|
flat_x = torch.flatten(x.detach())
|
||||||
x_quantile = torch.quantile(input=flat_x, q=self.range)
|
x_quantile = torch.quantile(input=flat_x, q=self.range)
|
||||||
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
|
# this should be in-place operation
|
||||||
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
|
ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
|
||||||
offset = self.values[0]
|
scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
|
||||||
|
offset = ema_vals[0]
|
||||||
return offset.detach(), scale.detach()
|
return offset.detach(), scale.detach()
|
||||||
|
|
||||||
|
|
||||||
@ -41,18 +39,13 @@ class WorldModel(nn.Module):
|
|||||||
config.dyn_stoch,
|
config.dyn_stoch,
|
||||||
config.dyn_deter,
|
config.dyn_deter,
|
||||||
config.dyn_hidden,
|
config.dyn_hidden,
|
||||||
config.dyn_input_layers,
|
|
||||||
config.dyn_output_layers,
|
|
||||||
config.dyn_rec_depth,
|
config.dyn_rec_depth,
|
||||||
config.dyn_shared,
|
|
||||||
config.dyn_discrete,
|
config.dyn_discrete,
|
||||||
config.act,
|
config.act,
|
||||||
config.norm,
|
config.norm,
|
||||||
config.dyn_mean_act,
|
config.dyn_mean_act,
|
||||||
config.dyn_std_act,
|
config.dyn_std_act,
|
||||||
config.dyn_temp_post,
|
|
||||||
config.dyn_min_std,
|
config.dyn_min_std,
|
||||||
config.dyn_cell,
|
|
||||||
config.unimix_ratio,
|
config.unimix_ratio,
|
||||||
config.initial,
|
config.initial,
|
||||||
config.num_actions,
|
config.num_actions,
|
||||||
@ -106,10 +99,10 @@ class WorldModel(nn.Module):
|
|||||||
print(
|
print(
|
||||||
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
|
||||||
)
|
)
|
||||||
|
# other losses are scaled by 1.0.
|
||||||
self._scales = dict(
|
self._scales = dict(
|
||||||
reward=config.reward_head["scale"],
|
reward=config.reward_head["loss_scale"],
|
||||||
cont=config.cont_head["scale"],
|
cont=config.cont_head["loss_scale"],
|
||||||
image=1.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train(self, data):
|
def _train(self, data):
|
||||||
@ -148,7 +141,8 @@ class WorldModel(nn.Module):
|
|||||||
assert loss.shape == embed.shape[:2], (name, loss.shape)
|
assert loss.shape == embed.shape[:2], (name, loss.shape)
|
||||||
losses[name] = loss
|
losses[name] = loss
|
||||||
scaled = {
|
scaled = {
|
||||||
key: value * self._scales[key] for key, value in losses.items()
|
key: value * self._scales.get(key, 1.0)
|
||||||
|
for key, value in losses.items()
|
||||||
}
|
}
|
||||||
model_loss = sum(scaled.values()) + kl_loss
|
model_loss = sum(scaled.values()) + kl_loss
|
||||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||||
@ -217,13 +211,11 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ImagBehavior(nn.Module):
|
class ImagBehavior(nn.Module):
|
||||||
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
|
def __init__(self, config, world_model):
|
||||||
super(ImagBehavior, self).__init__()
|
super(ImagBehavior, self).__init__()
|
||||||
self._use_amp = True if config.precision == 16 else False
|
self._use_amp = True if config.precision == 16 else False
|
||||||
self._config = config
|
self._config = config
|
||||||
self._world_model = world_model
|
self._world_model = world_model
|
||||||
self._stop_grad_actor = stop_grad_actor
|
|
||||||
self._reward = reward
|
|
||||||
if config.dyn_discrete:
|
if config.dyn_discrete:
|
||||||
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
|
||||||
else:
|
else:
|
||||||
@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
|
|||||||
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
|
||||||
)
|
)
|
||||||
if self._config.reward_EMA:
|
if self._config.reward_EMA:
|
||||||
|
# register ema_vals to nn.Module for enabling torch.save and torch.load
|
||||||
|
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
|
||||||
self.reward_ema = RewardEMA(device=self._config.device)
|
self.reward_ema = RewardEMA(device=self._config.device)
|
||||||
|
|
||||||
def _train(
|
def _train(
|
||||||
self,
|
self,
|
||||||
start,
|
start,
|
||||||
objective=None,
|
objective,
|
||||||
action=None,
|
|
||||||
reward=None,
|
|
||||||
imagine=None,
|
|
||||||
tape=None,
|
|
||||||
repeats=None,
|
|
||||||
):
|
):
|
||||||
objective = objective or self._reward
|
|
||||||
self._update_slow_target()
|
self._update_slow_target()
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
with tools.RequiresGrad(self.actor):
|
with tools.RequiresGrad(self.actor):
|
||||||
with torch.cuda.amp.autocast(self._use_amp):
|
with torch.cuda.amp.autocast(self._use_amp):
|
||||||
imag_feat, imag_state, imag_action = self._imagine(
|
imag_feat, imag_state, imag_action = self._imagine(
|
||||||
start, self.actor, self._config.imag_horizon, repeats
|
start, self.actor, self._config.imag_horizon
|
||||||
)
|
)
|
||||||
reward = objective(imag_feat, imag_state, imag_action)
|
reward = objective(imag_feat, imag_state, imag_action)
|
||||||
actor_ent = self.actor(imag_feat).entropy()
|
actor_ent = self.actor(imag_feat).entropy()
|
||||||
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
|
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
|
||||||
# this target is not scaled
|
# this target is not scaled by ema or sym_log.
|
||||||
# slow is flag to indicate whether slow_target is used for lambda-return
|
|
||||||
target, weights, base = self._compute_target(
|
target, weights, base = self._compute_target(
|
||||||
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
imag_feat, imag_state, reward
|
||||||
)
|
)
|
||||||
actor_loss, mets = self._compute_actor_loss(
|
actor_loss, mets = self._compute_actor_loss(
|
||||||
imag_feat,
|
imag_feat,
|
||||||
imag_state,
|
|
||||||
imag_action,
|
imag_action,
|
||||||
target,
|
target,
|
||||||
actor_ent,
|
|
||||||
state_ent,
|
|
||||||
weights,
|
weights,
|
||||||
base,
|
base,
|
||||||
)
|
)
|
||||||
@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
|
|||||||
metrics.update(self._value_opt(value_loss, self.value.parameters()))
|
metrics.update(self._value_opt(value_loss, self.value.parameters()))
|
||||||
return imag_feat, imag_state, imag_action, weights, metrics
|
return imag_feat, imag_state, imag_action, weights, metrics
|
||||||
|
|
||||||
def _imagine(self, start, policy, horizon, repeats=None):
|
def _imagine(self, start, policy, horizon):
|
||||||
dynamics = self._world_model.dynamics
|
dynamics = self._world_model.dynamics
|
||||||
if repeats:
|
|
||||||
raise NotImplemented("repeats is not implemented in this version")
|
|
||||||
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
|
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
|
||||||
start = {k: flatten(v) for k, v in start.items()}
|
start = {k: flatten(v) for k, v in start.items()}
|
||||||
|
|
||||||
def step(prev, _):
|
def step(prev, _):
|
||||||
state, _, _ = prev
|
state, _, _ = prev
|
||||||
feat = dynamics.get_feat(state)
|
feat = dynamics.get_feat(state)
|
||||||
inp = feat.detach() if self._stop_grad_actor else feat
|
inp = feat.detach()
|
||||||
action = policy(inp).sample()
|
action = policy(inp).sample()
|
||||||
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
|
succ = dynamics.img_step(state, action)
|
||||||
return succ, feat, action
|
return succ, feat, action
|
||||||
|
|
||||||
succ, feats, actions = tools.static_scan(
|
succ, feats, actions = tools.static_scan(
|
||||||
step, [torch.arange(horizon)], (start, None, None)
|
step, [torch.arange(horizon)], (start, None, None)
|
||||||
)
|
)
|
||||||
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
|
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
|
||||||
if repeats:
|
|
||||||
raise NotImplemented("repeats is not implemented in this version")
|
|
||||||
|
|
||||||
return feats, states, actions
|
return feats, states, actions
|
||||||
|
|
||||||
def _compute_target(
|
def _compute_target(self, imag_feat, imag_state, reward):
|
||||||
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
|
|
||||||
):
|
|
||||||
if "cont" in self._world_model.heads:
|
if "cont" in self._world_model.heads:
|
||||||
inp = self._world_model.dynamics.get_feat(imag_state)
|
inp = self._world_model.dynamics.get_feat(imag_state)
|
||||||
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
|
||||||
@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
|
|||||||
def _compute_actor_loss(
|
def _compute_actor_loss(
|
||||||
self,
|
self,
|
||||||
imag_feat,
|
imag_feat,
|
||||||
imag_state,
|
|
||||||
imag_action,
|
imag_action,
|
||||||
target,
|
target,
|
||||||
actor_ent,
|
|
||||||
state_ent,
|
|
||||||
weights,
|
weights,
|
||||||
base,
|
base,
|
||||||
):
|
):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
|
inp = imag_feat.detach()
|
||||||
policy = self.actor(inp)
|
policy = self.actor(inp)
|
||||||
actor_ent = policy.entropy()
|
|
||||||
# Q-val for actor is not transformed using symlog
|
# Q-val for actor is not transformed using symlog
|
||||||
target = torch.stack(target, dim=1)
|
target = torch.stack(target, dim=1)
|
||||||
if self._config.reward_EMA:
|
if self._config.reward_EMA:
|
||||||
offset, scale = self.reward_ema(target)
|
offset, scale = self.reward_ema(target, self.ema_vals)
|
||||||
normed_target = (target - offset) / scale
|
normed_target = (target - offset) / scale
|
||||||
normed_base = (base - offset) / scale
|
normed_base = (base - offset) / scale
|
||||||
adv = normed_target - normed_base
|
adv = normed_target - normed_base
|
||||||
metrics.update(tools.tensorstats(normed_target, "normed_target"))
|
metrics.update(tools.tensorstats(normed_target, "normed_target"))
|
||||||
values = self.reward_ema.values
|
metrics["EMA_005"] = to_np(self.ema_vals[0])
|
||||||
metrics["EMA_005"] = to_np(values[0])
|
metrics["EMA_095"] = to_np(self.ema_vals[1])
|
||||||
metrics["EMA_095"] = to_np(values[1])
|
|
||||||
|
|
||||||
if self._config.imag_gradient == "dynamics":
|
if self._config.imag_gradient == "dynamics":
|
||||||
actor_target = adv
|
actor_target = adv
|
||||||
|
110
networks.py
110
networks.py
@ -16,18 +16,13 @@ class RSSM(nn.Module):
|
|||||||
stoch=30,
|
stoch=30,
|
||||||
deter=200,
|
deter=200,
|
||||||
hidden=200,
|
hidden=200,
|
||||||
layers_input=1,
|
|
||||||
layers_output=1,
|
|
||||||
rec_depth=1,
|
rec_depth=1,
|
||||||
shared=False,
|
|
||||||
discrete=False,
|
discrete=False,
|
||||||
act="SiLU",
|
act="SiLU",
|
||||||
norm="LayerNorm",
|
norm=True,
|
||||||
mean_act="none",
|
mean_act="none",
|
||||||
std_act="softplus",
|
std_act="softplus",
|
||||||
temp_post=True,
|
|
||||||
min_std=0.1,
|
min_std=0.1,
|
||||||
cell="gru",
|
|
||||||
unimix_ratio=0.01,
|
unimix_ratio=0.01,
|
||||||
initial="learned",
|
initial="learned",
|
||||||
num_actions=None,
|
num_actions=None,
|
||||||
@ -39,16 +34,11 @@ class RSSM(nn.Module):
|
|||||||
self._deter = deter
|
self._deter = deter
|
||||||
self._hidden = hidden
|
self._hidden = hidden
|
||||||
self._min_std = min_std
|
self._min_std = min_std
|
||||||
self._layers_input = layers_input
|
|
||||||
self._layers_output = layers_output
|
|
||||||
self._rec_depth = rec_depth
|
self._rec_depth = rec_depth
|
||||||
self._shared = shared
|
|
||||||
self._discrete = discrete
|
self._discrete = discrete
|
||||||
act = getattr(torch.nn, act)
|
act = getattr(torch.nn, act)
|
||||||
norm = getattr(torch.nn, norm)
|
|
||||||
self._mean_act = mean_act
|
self._mean_act = mean_act
|
||||||
self._std_act = std_act
|
self._std_act = std_act
|
||||||
self._temp_post = temp_post
|
|
||||||
self._unimix_ratio = unimix_ratio
|
self._unimix_ratio = unimix_ratio
|
||||||
self._initial = initial
|
self._initial = initial
|
||||||
self._num_actions = num_actions
|
self._num_actions = num_actions
|
||||||
@ -60,47 +50,30 @@ class RSSM(nn.Module):
|
|||||||
inp_dim = self._stoch * self._discrete + num_actions
|
inp_dim = self._stoch * self._discrete + num_actions
|
||||||
else:
|
else:
|
||||||
inp_dim = self._stoch + num_actions
|
inp_dim = self._stoch + num_actions
|
||||||
if self._shared:
|
|
||||||
inp_dim += self._embed
|
|
||||||
for i in range(self._layers_input):
|
|
||||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||||
inp_layers.append(norm(self._hidden, eps=1e-03))
|
if norm:
|
||||||
|
inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||||
inp_layers.append(act())
|
inp_layers.append(act())
|
||||||
if i == 0:
|
|
||||||
inp_dim = self._hidden
|
|
||||||
self._img_in_layers = nn.Sequential(*inp_layers)
|
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||||
self._img_in_layers.apply(tools.weight_init)
|
self._img_in_layers.apply(tools.weight_init)
|
||||||
if cell == "gru":
|
self._cell = GRUCell(self._hidden, self._deter, norm=norm)
|
||||||
self._cell = GRUCell(self._hidden, self._deter)
|
|
||||||
self._cell.apply(tools.weight_init)
|
self._cell.apply(tools.weight_init)
|
||||||
elif cell == "gru_layer_norm":
|
|
||||||
self._cell = GRUCell(self._hidden, self._deter, norm=True)
|
|
||||||
self._cell.apply(tools.weight_init)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(cell)
|
|
||||||
|
|
||||||
img_out_layers = []
|
img_out_layers = []
|
||||||
inp_dim = self._deter
|
inp_dim = self._deter
|
||||||
for i in range(self._layers_output):
|
|
||||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||||
img_out_layers.append(norm(self._hidden, eps=1e-03))
|
if norm:
|
||||||
|
img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||||
img_out_layers.append(act())
|
img_out_layers.append(act())
|
||||||
if i == 0:
|
|
||||||
inp_dim = self._hidden
|
|
||||||
self._img_out_layers = nn.Sequential(*img_out_layers)
|
self._img_out_layers = nn.Sequential(*img_out_layers)
|
||||||
self._img_out_layers.apply(tools.weight_init)
|
self._img_out_layers.apply(tools.weight_init)
|
||||||
|
|
||||||
obs_out_layers = []
|
obs_out_layers = []
|
||||||
if self._temp_post:
|
|
||||||
inp_dim = self._deter + self._embed
|
inp_dim = self._deter + self._embed
|
||||||
else:
|
|
||||||
inp_dim = self._embed
|
|
||||||
for i in range(self._layers_output):
|
|
||||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||||
obs_out_layers.append(norm(self._hidden, eps=1e-03))
|
if norm:
|
||||||
|
obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||||
obs_out_layers.append(act())
|
obs_out_layers.append(act())
|
||||||
if i == 0:
|
|
||||||
inp_dim = self._hidden
|
|
||||||
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
||||||
self._obs_out_layers.apply(tools.weight_init)
|
self._obs_out_layers.apply(tools.weight_init)
|
||||||
|
|
||||||
@ -200,9 +173,6 @@ class RSSM(nn.Module):
|
|||||||
return dist
|
return dist
|
||||||
|
|
||||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
|
|
||||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
|
||||||
|
|
||||||
# initialize all prev_state
|
# initialize all prev_state
|
||||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||||
prev_state = self.initial(len(is_first))
|
prev_state = self.initial(len(is_first))
|
||||||
@ -223,14 +193,8 @@ class RSSM(nn.Module):
|
|||||||
val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||||
)
|
)
|
||||||
|
|
||||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
prior = self.img_step(prev_state, prev_action)
|
||||||
if self._shared:
|
|
||||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
|
||||||
else:
|
|
||||||
if self._temp_post:
|
|
||||||
x = torch.cat([prior["deter"], embed], -1)
|
x = torch.cat([prior["deter"], embed], -1)
|
||||||
else:
|
|
||||||
x = embed
|
|
||||||
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
||||||
x = self._obs_out_layers(x)
|
x = self._obs_out_layers(x)
|
||||||
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
||||||
@ -242,21 +206,14 @@ class RSSM(nn.Module):
|
|||||||
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
||||||
return post, prior
|
return post, prior
|
||||||
|
|
||||||
# this is used for making future image
|
def img_step(self, prev_state, prev_action, sample=True):
|
||||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
|
||||||
# (batch, stoch, discrete_num)
|
# (batch, stoch, discrete_num)
|
||||||
prev_stoch = prev_state["stoch"]
|
prev_stoch = prev_state["stoch"]
|
||||||
if self._discrete:
|
if self._discrete:
|
||||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||||
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
|
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
|
||||||
prev_stoch = prev_stoch.reshape(shape)
|
prev_stoch = prev_stoch.reshape(shape)
|
||||||
if self._shared:
|
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
|
||||||
if embed is None:
|
|
||||||
shape = list(prev_action.shape[:-1]) + [self._embed]
|
|
||||||
embed = torch.zeros(shape)
|
|
||||||
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
|
|
||||||
x = torch.cat([prev_stoch, prev_action, embed], -1)
|
|
||||||
else:
|
|
||||||
x = torch.cat([prev_stoch, prev_action], -1)
|
x = torch.cat([prev_stoch, prev_action], -1)
|
||||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||||
x = self._img_in_layers(x)
|
x = self._img_in_layers(x)
|
||||||
@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
|
|||||||
layers = []
|
layers = []
|
||||||
for i in range(stages):
|
for i in range(stages):
|
||||||
layers.append(
|
layers.append(
|
||||||
Conv2dSame(
|
Conv2dSamePad(
|
||||||
in_channels=in_dim,
|
in_channels=in_dim,
|
||||||
out_channels=out_dim,
|
out_channels=out_dim,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if norm:
|
if norm:
|
||||||
layers.append(ChLayerNorm(out_dim))
|
layers.append(ImgChLayerNorm(out_dim))
|
||||||
layers.append(act())
|
layers.append(act())
|
||||||
in_dim = out_dim
|
in_dim = out_dim
|
||||||
out_dim *= 2
|
out_dim *= 2
|
||||||
@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if norm:
|
if norm:
|
||||||
layers.append(ChLayerNorm(out_dim))
|
layers.append(ImgChLayerNorm(out_dim))
|
||||||
if act:
|
if act:
|
||||||
layers.append(act())
|
layers.append(act())
|
||||||
in_dim = out_dim
|
in_dim = out_dim
|
||||||
@ -637,7 +594,7 @@ class MLP(nn.Module):
|
|||||||
layers,
|
layers,
|
||||||
units,
|
units,
|
||||||
act="SiLU",
|
act="SiLU",
|
||||||
norm="LayerNorm",
|
norm=True,
|
||||||
dist="normal",
|
dist="normal",
|
||||||
std=1.0,
|
std=1.0,
|
||||||
min_std=0.1,
|
min_std=0.1,
|
||||||
@ -654,11 +611,9 @@ class MLP(nn.Module):
|
|||||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||||
if self._shape is not None and len(self._shape) == 0:
|
if self._shape is not None and len(self._shape) == 0:
|
||||||
self._shape = (1,)
|
self._shape = (1,)
|
||||||
self._layers = layers
|
|
||||||
act = getattr(torch.nn, act)
|
act = getattr(torch.nn, act)
|
||||||
norm = getattr(torch.nn, norm)
|
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
self._std = std
|
self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
|
||||||
self._min_std = min_std
|
self._min_std = min_std
|
||||||
self._max_std = max_std
|
self._max_std = max_std
|
||||||
self._absmax = absmax
|
self._absmax = absmax
|
||||||
@ -668,13 +623,16 @@ class MLP(nn.Module):
|
|||||||
self._device = device
|
self._device = device
|
||||||
|
|
||||||
self.layers = nn.Sequential()
|
self.layers = nn.Sequential()
|
||||||
for index in range(self._layers):
|
for i in range(layers):
|
||||||
self.layers.add_module(
|
self.layers.add_module(
|
||||||
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
|
f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
|
||||||
)
|
)
|
||||||
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
|
if norm:
|
||||||
self.layers.add_module(f"{name}_act{index}", act())
|
self.layers.add_module(
|
||||||
if index == 0:
|
f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
|
||||||
|
)
|
||||||
|
self.layers.add_module(f"{name}_act{i}", act())
|
||||||
|
if i == 0:
|
||||||
inp_dim = units
|
inp_dim = units
|
||||||
self.layers.apply(tools.weight_init)
|
self.layers.apply(tools.weight_init)
|
||||||
|
|
||||||
@ -783,16 +741,18 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GRUCell(nn.Module):
|
class GRUCell(nn.Module):
|
||||||
def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
|
def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
|
||||||
super(GRUCell, self).__init__()
|
super(GRUCell, self).__init__()
|
||||||
self._inp_size = inp_size
|
self._inp_size = inp_size
|
||||||
self._size = size
|
self._size = size
|
||||||
self._act = act
|
self._act = act
|
||||||
self._norm = norm
|
|
||||||
self._update_bias = update_bias
|
self._update_bias = update_bias
|
||||||
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
|
self.layers = nn.Sequential()
|
||||||
|
self.layers.add_module(
|
||||||
|
"GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||||
|
)
|
||||||
if norm:
|
if norm:
|
||||||
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
|
self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_size(self):
|
def state_size(self):
|
||||||
@ -800,9 +760,7 @@ class GRUCell(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inputs, state):
|
def forward(self, inputs, state):
|
||||||
state = state[0] # Keras wraps the state in a list.
|
state = state[0] # Keras wraps the state in a list.
|
||||||
parts = self._layer(torch.cat([inputs, state], -1))
|
parts = self.layers(torch.cat([inputs, state], -1))
|
||||||
if self._norm:
|
|
||||||
parts = self._norm(parts)
|
|
||||||
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
|
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
|
||||||
reset = torch.sigmoid(reset)
|
reset = torch.sigmoid(reset)
|
||||||
cand = self._act(reset * cand)
|
cand = self._act(reset * cand)
|
||||||
@ -811,7 +769,7 @@ class GRUCell(nn.Module):
|
|||||||
return output, [output]
|
return output, [output]
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSame(torch.nn.Conv2d):
|
class Conv2dSamePad(torch.nn.Conv2d):
|
||||||
def calc_same_pad(self, i, k, s, d):
|
def calc_same_pad(self, i, k, s, d):
|
||||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||||
|
|
||||||
@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class ChLayerNorm(nn.Module):
|
class ImgChLayerNorm(nn.Module):
|
||||||
def __init__(self, ch, eps=1e-03):
|
def __init__(self, ch, eps=1e-03):
|
||||||
super(ChLayerNorm, self).__init__()
|
super(ImgChLayerNorm, self).__init__()
|
||||||
self.norm = torch.nn.LayerNorm(ch, eps=eps)
|
self.norm = torch.nn.LayerNorm(ch, eps=eps)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
31
tools.py
31
tools.py
@ -840,37 +840,6 @@ def static_scan(fn, inputs, start):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Original version
|
|
||||||
# def static_scan2(fn, inputs, start, reverse=False):
|
|
||||||
# last = start
|
|
||||||
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
|
|
||||||
# indices = range(inputs[0].shape[0])
|
|
||||||
# if reverse:
|
|
||||||
# indices = reversed(indices)
|
|
||||||
# for index in indices:
|
|
||||||
# inp = lambda x: (_input[x] for _input in inputs)
|
|
||||||
# last = fn(last, *inp(index))
|
|
||||||
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
|
|
||||||
# if reverse:
|
|
||||||
# outputs = [list(reversed(x)) for x in outputs]
|
|
||||||
# res = [[]] * len(outputs)
|
|
||||||
# for i in range(len(outputs)):
|
|
||||||
# if type(outputs[i][0]) == type({}):
|
|
||||||
# _res = {}
|
|
||||||
# for key in outputs[i][0].keys():
|
|
||||||
# _res[key] = []
|
|
||||||
# for j in range(len(outputs[i])):
|
|
||||||
# _res[key].append(outputs[i][j][key])
|
|
||||||
# #_res[key] = torch.stack(_res[key], 0)
|
|
||||||
# _res[key] = faster_stack(_res[key], 0)
|
|
||||||
# else:
|
|
||||||
# _res = outputs[i]
|
|
||||||
# #_res = torch.stack(_res, 0)
|
|
||||||
# _res = faster_stack(_res, 0)
|
|
||||||
# res[i] = _res
|
|
||||||
# return res
|
|
||||||
|
|
||||||
|
|
||||||
class Every:
|
class Every:
|
||||||
def __init__(self, every):
|
def __init__(self, every):
|
||||||
self._every = every
|
self._every = every
|
||||||
|
Loading…
x
Reference in New Issue
Block a user