erased unused options

This commit is contained in:
NM512 2024-01-05 23:23:09 +09:00
parent a27711ab96
commit 7f66ed5333
6 changed files with 84 additions and 211 deletions

View File

@ -17,7 +17,6 @@ defaults:
compile: True
precision: 32
debug: False
expl_gifs: False
video_pred_log: True
# Environment
@ -28,27 +27,21 @@ defaults:
time_limit: 1000
grayscale: False
prefill: 2500
eval_noise: 0.0
reward_EMA: True
# Model
dyn_cell: 'gru_layer_norm'
dyn_hidden: 512
dyn_deter: 512
dyn_stoch: 32
dyn_discrete: 32
dyn_input_layers: 1
dyn_output_layers: 1
dyn_rec_depth: 1
dyn_shared: False
dyn_mean_act: 'none'
dyn_std_act: 'sigmoid2'
dyn_min_std: 0.1
dyn_temp_post: True
grad_heads: ['decoder', 'reward', 'cont']
units: 512
act: 'SiLU'
norm: 'LayerNorm'
norm: True
encoder:
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
decoder:
@ -58,9 +51,9 @@ defaults:
critic:
{layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
reward_head:
{layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
{layers: 2, dist: 'symlog_disc', loss_scale: 1.0, outscale: 0.0}
cont_head:
{layers: 2, scale: 1.0, outscale: 1.0}
{layers: 2, loss_scale: 1.0, outscale: 1.0}
dyn_scale: 0.5
rep_scale: 0.1
kl_free: 1.0
@ -85,12 +78,7 @@ defaults:
imag_horizon: 15
imag_gradient: 'dynamics'
imag_gradient_mix: 0.0
imag_sample: True
expl_amount: 0
eval_state_mean: False
collect_dyn_sample: True
behavior_stop_grad: True
future_entropy: False
# Exploration
expl_behavior: 'greedy'

View File

@ -42,9 +42,7 @@ class Dreamer(nn.Module):
self._update_count = 0
self._dataset = dataset
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
self._task_behavior = models.ImagBehavior(config, self._wm)
if (
config.compile and os.name != "nt"
): # compilation is not supported on windows
@ -92,9 +90,7 @@ class Dreamer(nn.Module):
latent, action = state
obs = self._wm.preprocess(obs)
embed = self._wm.encoder(obs)
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
)
latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
if self._config.eval_state_mean:
latent["stoch"] = latent["mean"]
feat = self._wm.dynamics.get_feat(latent)
@ -114,21 +110,10 @@ class Dreamer(nn.Module):
action = torch.one_hot(
torch.argmax(action, dim=-1), self._config.num_actions
)
action = self._exploration(action, training)
policy_output = {"action": action, "logprob": logprob}
state = (latent, action)
return policy_output, state
def _exploration(self, action, training):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
if "onehot" in self._config.actor["dist"]:
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
else:
return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
def _train(self, data):
metrics = {}
post, context, mets = self._wm._train(data)

View File

@ -38,7 +38,7 @@ class Random(nn.Module):
class Plan2Explore(nn.Module):
def __init__(self, config, world_model, reward=None):
def __init__(self, config, world_model, reward):
super(Plan2Explore, self).__init__()
self._config = config
self._use_amp = True if config.precision == 16 else False

View File

@ -1,8 +1,6 @@
import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont
import networks
import tools
@ -10,21 +8,21 @@ import tools
to_np = lambda x: x.detach().cpu().numpy()
class RewardEMA(object):
class RewardEMA:
"""running mean and std"""
def __init__(self, device, alpha=1e-2):
self.device = device
self.values = torch.zeros((2,)).to(device)
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
def __call__(self, x):
def __call__(self, x, ema_vals):
flat_x = torch.flatten(x.detach())
x_quantile = torch.quantile(input=flat_x, q=self.range)
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
offset = self.values[0]
# this should be in-place operation
ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals
scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0)
offset = ema_vals[0]
return offset.detach(), scale.detach()
@ -41,18 +39,13 @@ class WorldModel(nn.Module):
config.dyn_stoch,
config.dyn_deter,
config.dyn_hidden,
config.dyn_input_layers,
config.dyn_output_layers,
config.dyn_rec_depth,
config.dyn_shared,
config.dyn_discrete,
config.act,
config.norm,
config.dyn_mean_act,
config.dyn_std_act,
config.dyn_temp_post,
config.dyn_min_std,
config.dyn_cell,
config.unimix_ratio,
config.initial,
config.num_actions,
@ -106,10 +99,10 @@ class WorldModel(nn.Module):
print(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
# other losses are scaled by 1.0.
self._scales = dict(
reward=config.reward_head["scale"],
cont=config.cont_head["scale"],
image=1.0,
reward=config.reward_head["loss_scale"],
cont=config.cont_head["loss_scale"],
)
def _train(self, data):
@ -148,7 +141,8 @@ class WorldModel(nn.Module):
assert loss.shape == embed.shape[:2], (name, loss.shape)
losses[name] = loss
scaled = {
key: value * self._scales[key] for key, value in losses.items()
key: value * self._scales.get(key, 1.0)
for key, value in losses.items()
}
model_loss = sum(scaled.values()) + kl_loss
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
@ -217,13 +211,11 @@ class WorldModel(nn.Module):
class ImagBehavior(nn.Module):
def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
def __init__(self, config, world_model):
super(ImagBehavior, self).__init__()
self._use_amp = True if config.precision == 16 else False
self._config = config
self._world_model = world_model
self._stop_grad_actor = stop_grad_actor
self._reward = reward
if config.dyn_discrete:
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
else:
@ -284,42 +276,34 @@ class ImagBehavior(nn.Module):
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
)
if self._config.reward_EMA:
# register ema_vals to nn.Module for enabling torch.save and torch.load
self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device))
self.reward_ema = RewardEMA(device=self._config.device)
def _train(
self,
start,
objective=None,
action=None,
reward=None,
imagine=None,
tape=None,
repeats=None,
objective,
):
objective = objective or self._reward
self._update_slow_target()
metrics = {}
with tools.RequiresGrad(self.actor):
with torch.cuda.amp.autocast(self._use_amp):
imag_feat, imag_state, imag_action = self._imagine(
start, self.actor, self._config.imag_horizon, repeats
start, self.actor, self._config.imag_horizon
)
reward = objective(imag_feat, imag_state, imag_action)
actor_ent = self.actor(imag_feat).entropy()
state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
# this target is not scaled
# slow is flag to indicate whether slow_target is used for lambda-return
# this target is not scaled by ema or sym_log.
target, weights, base = self._compute_target(
imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
imag_feat, imag_state, reward
)
actor_loss, mets = self._compute_actor_loss(
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
)
@ -357,33 +341,27 @@ class ImagBehavior(nn.Module):
metrics.update(self._value_opt(value_loss, self.value.parameters()))
return imag_feat, imag_state, imag_action, weights, metrics
def _imagine(self, start, policy, horizon, repeats=None):
def _imagine(self, start, policy, horizon):
dynamics = self._world_model.dynamics
if repeats:
raise NotImplemented("repeats is not implemented in this version")
flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
start = {k: flatten(v) for k, v in start.items()}
def step(prev, _):
state, _, _ = prev
feat = dynamics.get_feat(state)
inp = feat.detach() if self._stop_grad_actor else feat
inp = feat.detach()
action = policy(inp).sample()
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
succ = dynamics.img_step(state, action)
return succ, feat, action
succ, feats, actions = tools.static_scan(
step, [torch.arange(horizon)], (start, None, None)
)
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
if repeats:
raise NotImplemented("repeats is not implemented in this version")
return feats, states, actions
def _compute_target(
self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
):
def _compute_target(self, imag_feat, imag_state, reward):
if "cont" in self._world_model.heads:
inp = self._world_model.dynamics.get_feat(imag_state)
discount = self._config.discount * self._world_model.heads["cont"](inp).mean
@ -406,29 +384,24 @@ class ImagBehavior(nn.Module):
def _compute_actor_loss(
self,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
):
metrics = {}
inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
inp = imag_feat.detach()
policy = self.actor(inp)
actor_ent = policy.entropy()
# Q-val for actor is not transformed using symlog
target = torch.stack(target, dim=1)
if self._config.reward_EMA:
offset, scale = self.reward_ema(target)
offset, scale = self.reward_ema(target, self.ema_vals)
normed_target = (target - offset) / scale
normed_base = (base - offset) / scale
adv = normed_target - normed_base
metrics.update(tools.tensorstats(normed_target, "normed_target"))
values = self.reward_ema.values
metrics["EMA_005"] = to_np(values[0])
metrics["EMA_095"] = to_np(values[1])
metrics["EMA_005"] = to_np(self.ema_vals[0])
metrics["EMA_095"] = to_np(self.ema_vals[1])
if self._config.imag_gradient == "dynamics":
actor_target = adv

View File

@ -16,18 +16,13 @@ class RSSM(nn.Module):
stoch=30,
deter=200,
hidden=200,
layers_input=1,
layers_output=1,
rec_depth=1,
shared=False,
discrete=False,
act="SiLU",
norm="LayerNorm",
norm=True,
mean_act="none",
std_act="softplus",
temp_post=True,
min_std=0.1,
cell="gru",
unimix_ratio=0.01,
initial="learned",
num_actions=None,
@ -39,16 +34,11 @@ class RSSM(nn.Module):
self._deter = deter
self._hidden = hidden
self._min_std = min_std
self._layers_input = layers_input
self._layers_output = layers_output
self._rec_depth = rec_depth
self._shared = shared
self._discrete = discrete
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._mean_act = mean_act
self._std_act = std_act
self._temp_post = temp_post
self._unimix_ratio = unimix_ratio
self._initial = initial
self._num_actions = num_actions
@ -60,47 +50,30 @@ class RSSM(nn.Module):
inp_dim = self._stoch * self._discrete + num_actions
else:
inp_dim = self._stoch + num_actions
if self._shared:
inp_dim += self._embed
for i in range(self._layers_input):
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
inp_layers.append(norm(self._hidden, eps=1e-03))
inp_layers.append(act())
if i == 0:
inp_dim = self._hidden
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
inp_layers.append(act())
self._img_in_layers = nn.Sequential(*inp_layers)
self._img_in_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
elif cell == "gru_layer_norm":
self._cell = GRUCell(self._hidden, self._deter, norm=True)
self._cell.apply(tools.weight_init)
else:
raise NotImplementedError(cell)
self._cell = GRUCell(self._hidden, self._deter, norm=norm)
self._cell.apply(tools.weight_init)
img_out_layers = []
inp_dim = self._deter
for i in range(self._layers_output):
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
img_out_layers.append(norm(self._hidden, eps=1e-03))
img_out_layers.append(act())
if i == 0:
inp_dim = self._hidden
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
img_out_layers.append(act())
self._img_out_layers = nn.Sequential(*img_out_layers)
self._img_out_layers.apply(tools.weight_init)
obs_out_layers = []
if self._temp_post:
inp_dim = self._deter + self._embed
else:
inp_dim = self._embed
for i in range(self._layers_output):
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
obs_out_layers.append(norm(self._hidden, eps=1e-03))
obs_out_layers.append(act())
if i == 0:
inp_dim = self._hidden
inp_dim = self._deter + self._embed
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
obs_out_layers.append(act())
self._obs_out_layers = nn.Sequential(*obs_out_layers)
self._obs_out_layers.apply(tools.weight_init)
@ -200,9 +173,6 @@ class RSSM(nn.Module):
return dist
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
prev_state = self.initial(len(is_first))
@ -223,41 +193,28 @@ class RSSM(nn.Module):
val * (1.0 - is_first_r) + init_state[key] * is_first_r
)
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
prior = self.img_step(prev_state, prev_action)
x = torch.cat([prior["deter"], embed], -1)
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
x = self._obs_out_layers(x)
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
stats = self._suff_stats_layer("obs", x)
if sample:
stoch = self.get_dist(stats).sample()
else:
if self._temp_post:
x = torch.cat([prior["deter"], embed], -1)
else:
x = embed
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
x = self._obs_out_layers(x)
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
stats = self._suff_stats_layer("obs", x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
post = {"stoch": stoch, "deter": prior["deter"], **stats}
stoch = self.get_dist(stats).mode()
post = {"stoch": stoch, "deter": prior["deter"], **stats}
return post, prior
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
def img_step(self, prev_state, prev_action, sample=True):
# (batch, stoch, discrete_num)
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
prev_stoch = prev_stoch.reshape(shape)
if self._shared:
if embed is None:
shape = list(prev_action.shape[:-1]) + [self._embed]
embed = torch.zeros(shape)
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
x = torch.cat([prev_stoch, prev_action, embed], -1)
else:
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
x = self._img_in_layers(x)
for _ in range(self._rec_depth): # rec depth is not correctly implemented
@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
layers = []
for i in range(stages):
layers.append(
Conv2dSame(
Conv2dSamePad(
in_channels=in_dim,
out_channels=out_dim,
kernel_size=kernel_size,
@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
)
)
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(ImgChLayerNorm(out_dim))
layers.append(act())
in_dim = out_dim
out_dim *= 2
@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
)
)
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(ImgChLayerNorm(out_dim))
if act:
layers.append(act())
in_dim = out_dim
@ -637,7 +594,7 @@ class MLP(nn.Module):
layers,
units,
act="SiLU",
norm="LayerNorm",
norm=True,
dist="normal",
std=1.0,
min_std=0.1,
@ -654,11 +611,9 @@ class MLP(nn.Module):
self._shape = (shape,) if isinstance(shape, int) else shape
if self._shape is not None and len(self._shape) == 0:
self._shape = (1,)
self._layers = layers
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._dist = dist
self._std = std
self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
self._min_std = min_std
self._max_std = max_std
self._absmax = absmax
@ -668,13 +623,16 @@ class MLP(nn.Module):
self._device = device
self.layers = nn.Sequential()
for index in range(self._layers):
for i in range(layers):
self.layers.add_module(
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
)
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
self.layers.add_module(f"{name}_act{index}", act())
if index == 0:
if norm:
self.layers.add_module(
f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
)
self.layers.add_module(f"{name}_act{i}", act())
if i == 0:
inp_dim = units
self.layers.apply(tools.weight_init)
@ -783,16 +741,18 @@ class MLP(nn.Module):
class GRUCell(nn.Module):
def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
super(GRUCell, self).__init__()
self._inp_size = inp_size
self._size = size
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
self.layers = nn.Sequential()
self.layers.add_module(
"GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
)
if norm:
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
@property
def state_size(self):
@ -800,9 +760,7 @@ class GRUCell(nn.Module):
def forward(self, inputs, state):
state = state[0] # Keras wraps the state in a list.
parts = self._layer(torch.cat([inputs, state], -1))
if self._norm:
parts = self._norm(parts)
parts = self.layers(torch.cat([inputs, state], -1))
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
reset = torch.sigmoid(reset)
cand = self._act(reset * cand)
@ -811,7 +769,7 @@ class GRUCell(nn.Module):
return output, [output]
class Conv2dSame(torch.nn.Conv2d):
class Conv2dSamePad(torch.nn.Conv2d):
def calc_same_pad(self, i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d):
return ret
class ChLayerNorm(nn.Module):
class ImgChLayerNorm(nn.Module):
def __init__(self, ch, eps=1e-03):
super(ChLayerNorm, self).__init__()
super(ImgChLayerNorm, self).__init__()
self.norm = torch.nn.LayerNorm(ch, eps=eps)
def forward(self, x):

View File

@ -840,37 +840,6 @@ def static_scan(fn, inputs, start):
return outputs
# Original version
# def static_scan2(fn, inputs, start, reverse=False):
# last = start
# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))]
# indices = range(inputs[0].shape[0])
# if reverse:
# indices = reversed(indices)
# for index in indices:
# inp = lambda x: (_input[x] for _input in inputs)
# last = fn(last, *inp(index))
# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)]
# if reverse:
# outputs = [list(reversed(x)) for x in outputs]
# res = [[]] * len(outputs)
# for i in range(len(outputs)):
# if type(outputs[i][0]) == type({}):
# _res = {}
# for key in outputs[i][0].keys():
# _res[key] = []
# for j in range(len(outputs[i])):
# _res[key].append(outputs[i][j][key])
# #_res[key] = torch.stack(_res[key], 0)
# _res[key] = faster_stack(_res[key], 0)
# else:
# _res = outputs[i]
# #_res = torch.stack(_res, 0)
# _res = faster_stack(_res, 0)
# res[i] = _res
# return res
class Every:
def __init__(self, every):
self._every = every