diff --git a/configs.yaml b/configs.yaml index 9879ac6..f5c99b2 100644 --- a/configs.yaml +++ b/configs.yaml @@ -17,7 +17,6 @@ defaults: compile: True precision: 32 debug: False - expl_gifs: False video_pred_log: True # Environment @@ -28,27 +27,21 @@ defaults: time_limit: 1000 grayscale: False prefill: 2500 - eval_noise: 0.0 reward_EMA: True # Model - dyn_cell: 'gru_layer_norm' dyn_hidden: 512 dyn_deter: 512 dyn_stoch: 32 dyn_discrete: 32 - dyn_input_layers: 1 - dyn_output_layers: 1 dyn_rec_depth: 1 - dyn_shared: False dyn_mean_act: 'none' dyn_std_act: 'sigmoid2' dyn_min_std: 0.1 - dyn_temp_post: True grad_heads: ['decoder', 'reward', 'cont'] units: 512 act: 'SiLU' - norm: 'LayerNorm' + norm: True encoder: {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True} decoder: @@ -58,9 +51,9 @@ defaults: critic: {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0} reward_head: - {layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0} + {layers: 2, dist: 'symlog_disc', loss_scale: 1.0, outscale: 0.0} cont_head: - {layers: 2, scale: 1.0, outscale: 1.0} + {layers: 2, loss_scale: 1.0, outscale: 1.0} dyn_scale: 0.5 rep_scale: 0.1 kl_free: 1.0 @@ -85,12 +78,7 @@ defaults: imag_horizon: 15 imag_gradient: 'dynamics' imag_gradient_mix: 0.0 - imag_sample: True - expl_amount: 0 eval_state_mean: False - collect_dyn_sample: True - behavior_stop_grad: True - future_entropy: False # Exploration expl_behavior: 'greedy' diff --git a/dreamer.py b/dreamer.py index 4584e36..7384845 100644 --- a/dreamer.py +++ b/dreamer.py @@ -42,9 +42,7 @@ class Dreamer(nn.Module): self._update_count = 0 self._dataset = dataset self._wm = models.WorldModel(obs_space, act_space, self._step, config) - self._task_behavior = models.ImagBehavior( - config, self._wm, config.behavior_stop_grad - ) + self._task_behavior = models.ImagBehavior(config, self._wm) if ( config.compile and os.name != "nt" ): # compilation is not supported on windows @@ -92,9 +90,7 @@ class Dreamer(nn.Module): latent, action = state obs = self._wm.preprocess(obs) embed = self._wm.encoder(obs) - latent, _ = self._wm.dynamics.obs_step( - latent, action, embed, obs["is_first"], self._config.collect_dyn_sample - ) + latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"]) if self._config.eval_state_mean: latent["stoch"] = latent["mean"] feat = self._wm.dynamics.get_feat(latent) @@ -114,21 +110,10 @@ class Dreamer(nn.Module): action = torch.one_hot( torch.argmax(action, dim=-1), self._config.num_actions ) - action = self._exploration(action, training) policy_output = {"action": action, "logprob": logprob} state = (latent, action) return policy_output, state - def _exploration(self, action, training): - amount = self._config.expl_amount if training else self._config.eval_noise - if amount == 0: - return action - if "onehot" in self._config.actor["dist"]: - probs = amount / self._config.num_actions + (1 - amount) * action - return tools.OneHotDist(probs=probs).sample() - else: - return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1) - def _train(self, data): metrics = {} post, context, mets = self._wm._train(data) diff --git a/exploration.py b/exploration.py index 1121c10..98d5231 100644 --- a/exploration.py +++ b/exploration.py @@ -38,7 +38,7 @@ class Random(nn.Module): class Plan2Explore(nn.Module): - def __init__(self, config, 
world_model, reward=None): + def __init__(self, config, world_model, reward): super(Plan2Explore, self).__init__() self._config = config self._use_amp = True if config.precision == 16 else False diff --git a/models.py b/models.py index ace95b4..6aed023 100644 --- a/models.py +++ b/models.py @@ -1,8 +1,6 @@ import copy import torch from torch import nn -import numpy as np -from PIL import ImageColor, Image, ImageDraw, ImageFont import networks import tools @@ -10,21 +8,21 @@ import tools to_np = lambda x: x.detach().cpu().numpy() -class RewardEMA(object): +class RewardEMA: """running mean and std""" def __init__(self, device, alpha=1e-2): self.device = device - self.values = torch.zeros((2,)).to(device) self.alpha = alpha self.range = torch.tensor([0.05, 0.95]).to(device) - def __call__(self, x): + def __call__(self, x, ema_vals): flat_x = torch.flatten(x.detach()) x_quantile = torch.quantile(input=flat_x, q=self.range) - self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values - scale = torch.clip(self.values[1] - self.values[0], min=1.0) - offset = self.values[0] + # this should be in-place operation + ema_vals[:] = self.alpha * x_quantile + (1 - self.alpha) * ema_vals + scale = torch.clip(ema_vals[1] - ema_vals[0], min=1.0) + offset = ema_vals[0] return offset.detach(), scale.detach() @@ -41,18 +39,13 @@ class WorldModel(nn.Module): config.dyn_stoch, config.dyn_deter, config.dyn_hidden, - config.dyn_input_layers, - config.dyn_output_layers, config.dyn_rec_depth, - config.dyn_shared, config.dyn_discrete, config.act, config.norm, config.dyn_mean_act, config.dyn_std_act, - config.dyn_temp_post, config.dyn_min_std, - config.dyn_cell, config.unimix_ratio, config.initial, config.num_actions, @@ -106,10 +99,10 @@ class WorldModel(nn.Module): print( f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables." ) + # other losses are scaled by 1.0. self._scales = dict( - reward=config.reward_head["scale"], - cont=config.cont_head["scale"], - image=1.0, + reward=config.reward_head["loss_scale"], + cont=config.cont_head["loss_scale"], ) def _train(self, data): @@ -148,7 +141,8 @@ class WorldModel(nn.Module): assert loss.shape == embed.shape[:2], (name, loss.shape) losses[name] = loss scaled = { - key: value * self._scales[key] for key, value in losses.items() + key: value * self._scales.get(key, 1.0) + for key, value in losses.items() } model_loss = sum(scaled.values()) + kl_loss metrics = self._model_opt(torch.mean(model_loss), self.parameters()) @@ -217,13 +211,11 @@ class WorldModel(nn.Module): class ImagBehavior(nn.Module): - def __init__(self, config, world_model, stop_grad_actor=True, reward=None): + def __init__(self, config, world_model): super(ImagBehavior, self).__init__() self._use_amp = True if config.precision == 16 else False self._config = config self._world_model = world_model - self._stop_grad_actor = stop_grad_actor - self._reward = reward if config.dyn_discrete: feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter else: @@ -284,42 +276,34 @@ class ImagBehavior(nn.Module): f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables." 
) if self._config.reward_EMA: + # register ema_vals to nn.Module for enabling torch.save and torch.load + self.register_buffer("ema_vals", torch.zeros((2,)).to(self._config.device)) self.reward_ema = RewardEMA(device=self._config.device) def _train( self, start, - objective=None, - action=None, - reward=None, - imagine=None, - tape=None, - repeats=None, + objective, ): - objective = objective or self._reward self._update_slow_target() metrics = {} with tools.RequiresGrad(self.actor): with torch.cuda.amp.autocast(self._use_amp): imag_feat, imag_state, imag_action = self._imagine( - start, self.actor, self._config.imag_horizon, repeats + start, self.actor, self._config.imag_horizon ) reward = objective(imag_feat, imag_state, imag_action) actor_ent = self.actor(imag_feat).entropy() state_ent = self._world_model.dynamics.get_dist(imag_state).entropy() - # this target is not scaled - # slow is flag to indicate whether slow_target is used for lambda-return + # this target is not scaled by ema or sym_log. target, weights, base = self._compute_target( - imag_feat, imag_state, imag_action, reward, actor_ent, state_ent + imag_feat, imag_state, reward ) actor_loss, mets = self._compute_actor_loss( imag_feat, - imag_state, imag_action, target, - actor_ent, - state_ent, weights, base, ) @@ -357,33 +341,27 @@ class ImagBehavior(nn.Module): metrics.update(self._value_opt(value_loss, self.value.parameters())) return imag_feat, imag_state, imag_action, weights, metrics - def _imagine(self, start, policy, horizon, repeats=None): + def _imagine(self, start, policy, horizon): dynamics = self._world_model.dynamics - if repeats: - raise NotImplemented("repeats is not implemented in this version") flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) start = {k: flatten(v) for k, v in start.items()} def step(prev, _): state, _, _ = prev feat = dynamics.get_feat(state) - inp = feat.detach() if self._stop_grad_actor else feat + inp = feat.detach() action = policy(inp).sample() - succ = dynamics.img_step(state, action, sample=self._config.imag_sample) + succ = dynamics.img_step(state, action) return succ, feat, action succ, feats, actions = tools.static_scan( step, [torch.arange(horizon)], (start, None, None) ) states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} - if repeats: - raise NotImplemented("repeats is not implemented in this version") return feats, states, actions - def _compute_target( - self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent - ): + def _compute_target(self, imag_feat, imag_state, reward): if "cont" in self._world_model.heads: inp = self._world_model.dynamics.get_feat(imag_state) discount = self._config.discount * self._world_model.heads["cont"](inp).mean @@ -406,29 +384,24 @@ class ImagBehavior(nn.Module): def _compute_actor_loss( self, imag_feat, - imag_state, imag_action, target, - actor_ent, - state_ent, weights, base, ): metrics = {} - inp = imag_feat.detach() if self._stop_grad_actor else imag_feat + inp = imag_feat.detach() policy = self.actor(inp) - actor_ent = policy.entropy() # Q-val for actor is not transformed using symlog target = torch.stack(target, dim=1) if self._config.reward_EMA: - offset, scale = self.reward_ema(target) + offset, scale = self.reward_ema(target, self.ema_vals) normed_target = (target - offset) / scale normed_base = (base - offset) / scale adv = normed_target - normed_base metrics.update(tools.tensorstats(normed_target, "normed_target")) - values = self.reward_ema.values - metrics["EMA_005"] = to_np(values[0]) - 
metrics["EMA_095"] = to_np(values[1]) + metrics["EMA_005"] = to_np(self.ema_vals[0]) + metrics["EMA_095"] = to_np(self.ema_vals[1]) if self._config.imag_gradient == "dynamics": actor_target = adv diff --git a/networks.py b/networks.py index 1ad5772..cf9a39c 100644 --- a/networks.py +++ b/networks.py @@ -16,18 +16,13 @@ class RSSM(nn.Module): stoch=30, deter=200, hidden=200, - layers_input=1, - layers_output=1, rec_depth=1, - shared=False, discrete=False, act="SiLU", - norm="LayerNorm", + norm=True, mean_act="none", std_act="softplus", - temp_post=True, min_std=0.1, - cell="gru", unimix_ratio=0.01, initial="learned", num_actions=None, @@ -39,16 +34,11 @@ class RSSM(nn.Module): self._deter = deter self._hidden = hidden self._min_std = min_std - self._layers_input = layers_input - self._layers_output = layers_output self._rec_depth = rec_depth - self._shared = shared self._discrete = discrete act = getattr(torch.nn, act) - norm = getattr(torch.nn, norm) self._mean_act = mean_act self._std_act = std_act - self._temp_post = temp_post self._unimix_ratio = unimix_ratio self._initial = initial self._num_actions = num_actions @@ -60,47 +50,30 @@ class RSSM(nn.Module): inp_dim = self._stoch * self._discrete + num_actions else: inp_dim = self._stoch + num_actions - if self._shared: - inp_dim += self._embed - for i in range(self._layers_input): - inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - inp_layers.append(norm(self._hidden, eps=1e-03)) - inp_layers.append(act()) - if i == 0: - inp_dim = self._hidden + inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + if norm: + inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03)) + inp_layers.append(act()) self._img_in_layers = nn.Sequential(*inp_layers) self._img_in_layers.apply(tools.weight_init) - if cell == "gru": - self._cell = GRUCell(self._hidden, self._deter) - self._cell.apply(tools.weight_init) - elif cell == "gru_layer_norm": - self._cell = GRUCell(self._hidden, self._deter, norm=True) - self._cell.apply(tools.weight_init) - else: - raise NotImplementedError(cell) + self._cell = GRUCell(self._hidden, self._deter, norm=norm) + self._cell.apply(tools.weight_init) img_out_layers = [] inp_dim = self._deter - for i in range(self._layers_output): - img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - img_out_layers.append(norm(self._hidden, eps=1e-03)) - img_out_layers.append(act()) - if i == 0: - inp_dim = self._hidden + img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + if norm: + img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03)) + img_out_layers.append(act()) self._img_out_layers = nn.Sequential(*img_out_layers) self._img_out_layers.apply(tools.weight_init) obs_out_layers = [] - if self._temp_post: - inp_dim = self._deter + self._embed - else: - inp_dim = self._embed - for i in range(self._layers_output): - obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - obs_out_layers.append(norm(self._hidden, eps=1e-03)) - obs_out_layers.append(act()) - if i == 0: - inp_dim = self._hidden + inp_dim = self._deter + self._embed + obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) + if norm: + obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03)) + obs_out_layers.append(act()) self._obs_out_layers = nn.Sequential(*obs_out_layers) self._obs_out_layers.apply(tools.weight_init) @@ -200,9 +173,6 @@ class RSSM(nn.Module): return dist def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): - # if shared is True, prior and post 
both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer) - # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs - # initialize all prev_state if prev_state == None or torch.sum(is_first) == len(is_first): prev_state = self.initial(len(is_first)) @@ -223,41 +193,28 @@ class RSSM(nn.Module): val * (1.0 - is_first_r) + init_state[key] * is_first_r ) - prior = self.img_step(prev_state, prev_action, None, sample) - if self._shared: - post = self.img_step(prev_state, prev_action, embed, sample) + prior = self.img_step(prev_state, prev_action) + x = torch.cat([prior["deter"], embed], -1) + # (batch_size, prior_deter + embed) -> (batch_size, hidden) + x = self._obs_out_layers(x) + # (batch_size, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("obs", x) + if sample: + stoch = self.get_dist(stats).sample() else: - if self._temp_post: - x = torch.cat([prior["deter"], embed], -1) - else: - x = embed - # (batch_size, prior_deter + embed) -> (batch_size, hidden) - x = self._obs_out_layers(x) - # (batch_size, hidden) -> (batch_size, stoch, discrete_num) - stats = self._suff_stats_layer("obs", x) - if sample: - stoch = self.get_dist(stats).sample() - else: - stoch = self.get_dist(stats).mode() - post = {"stoch": stoch, "deter": prior["deter"], **stats} + stoch = self.get_dist(stats).mode() + post = {"stoch": stoch, "deter": prior["deter"], **stats} return post, prior - # this is used for making future image - def img_step(self, prev_state, prev_action, embed=None, sample=True): + def img_step(self, prev_state, prev_action, sample=True): # (batch, stoch, discrete_num) prev_stoch = prev_state["stoch"] if self._discrete: shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num) prev_stoch = prev_stoch.reshape(shape) - if self._shared: - if embed is None: - shape = list(prev_action.shape[:-1]) + [self._embed] - embed = torch.zeros(shape) - # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed) - x = torch.cat([prev_stoch, prev_action, embed], -1) - else: - x = torch.cat([prev_stoch, prev_action], -1) + # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action) + x = torch.cat([prev_stoch, prev_action], -1) # (batch, stoch * discrete_num + action, embed) -> (batch, hidden) x = self._img_in_layers(x) for _ in range(self._rec_depth): # rec depth is not correctly implemented @@ -508,7 +465,7 @@ class ConvEncoder(nn.Module): layers = [] for i in range(stages): layers.append( - Conv2dSame( + Conv2dSamePad( in_channels=in_dim, out_channels=out_dim, kernel_size=kernel_size, @@ -517,7 +474,7 @@ class ConvEncoder(nn.Module): ) ) if norm: - layers.append(ChLayerNorm(out_dim)) + layers.append(ImgChLayerNorm(out_dim)) layers.append(act()) in_dim = out_dim out_dim *= 2 @@ -593,7 +550,7 @@ class ConvDecoder(nn.Module): ) ) if norm: - layers.append(ChLayerNorm(out_dim)) + layers.append(ImgChLayerNorm(out_dim)) if act: layers.append(act()) in_dim = out_dim @@ -637,7 +594,7 @@ class MLP(nn.Module): layers, units, act="SiLU", - norm="LayerNorm", + norm=True, dist="normal", std=1.0, min_std=0.1, @@ -654,11 +611,9 @@ class MLP(nn.Module): self._shape = (shape,) if isinstance(shape, int) else shape if self._shape is not None and len(self._shape) == 0: self._shape = (1,) - self._layers = layers act = getattr(torch.nn, act) - norm = getattr(torch.nn, norm) self._dist = dist - self._std = std + self._std = std if isinstance(std, str) 
else torch.tensor((std,), device=device) self._min_std = min_std self._max_std = max_std self._absmax = absmax @@ -668,13 +623,16 @@ class MLP(nn.Module): self._device = device self.layers = nn.Sequential() - for index in range(self._layers): + for i in range(layers): self.layers.add_module( - f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False) + f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False) ) - self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03)) - self.layers.add_module(f"{name}_act{index}", act()) - if index == 0: + if norm: + self.layers.add_module( + f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03) + ) + self.layers.add_module(f"{name}_act{i}", act()) + if i == 0: inp_dim = units self.layers.apply(tools.weight_init) @@ -783,16 +741,18 @@ class MLP(nn.Module): class GRUCell(nn.Module): - def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1): + def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1): super(GRUCell, self).__init__() self._inp_size = inp_size self._size = size self._act = act - self._norm = norm self._update_bias = update_bias - self._layer = nn.Linear(inp_size + size, 3 * size, bias=False) + self.layers = nn.Sequential() + self.layers.add_module( + "GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False) + ) if norm: - self._norm = nn.LayerNorm(3 * size, eps=1e-03) + self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03)) @property def state_size(self): @@ -800,9 +760,7 @@ class GRUCell(nn.Module): def forward(self, inputs, state): state = state[0] # Keras wraps the state in a list. - parts = self._layer(torch.cat([inputs, state], -1)) - if self._norm: - parts = self._norm(parts) + parts = self.layers(torch.cat([inputs, state], -1)) reset, cand, update = torch.split(parts, [self._size] * 3, -1) reset = torch.sigmoid(reset) cand = self._act(reset * cand) @@ -811,7 +769,7 @@ class GRUCell(nn.Module): return output, [output] -class Conv2dSame(torch.nn.Conv2d): +class Conv2dSamePad(torch.nn.Conv2d): def calc_same_pad(self, i, k, s, d): return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) @@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d): return ret -class ChLayerNorm(nn.Module): +class ImgChLayerNorm(nn.Module): def __init__(self, ch, eps=1e-03): - super(ChLayerNorm, self).__init__() + super(ImgChLayerNorm, self).__init__() self.norm = torch.nn.LayerNorm(ch, eps=eps) def forward(self, x): diff --git a/tools.py b/tools.py index cb09056..f80384e 100644 --- a/tools.py +++ b/tools.py @@ -840,37 +840,6 @@ def static_scan(fn, inputs, start): return outputs -# Original version -# def static_scan2(fn, inputs, start, reverse=False): -# last = start -# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))] -# indices = range(inputs[0].shape[0]) -# if reverse: -# indices = reversed(indices) -# for index in indices: -# inp = lambda x: (_input[x] for _input in inputs) -# last = fn(last, *inp(index)) -# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)] -# if reverse: -# outputs = [list(reversed(x)) for x in outputs] -# res = [[]] * len(outputs) -# for i in range(len(outputs)): -# if type(outputs[i][0]) == type({}): -# _res = {} -# for key in outputs[i][0].keys(): -# _res[key] = [] -# for j in range(len(outputs[i])): -# _res[key].append(outputs[i][j][key]) -# #_res[key] = torch.stack(_res[key], 0) -# _res[key] = faster_stack(_res[key], 0) -# else: -# _res = outputs[i] -# #_res = torch.stack(_res, 0) -# _res = 
faster_stack(_res, 0) -# res[i] = _res -# return res - - class Every: def __init__(self, every): self._every = every
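
Why RewardEMA now takes ema_vals as an argument: the running return quantiles used to live on the RewardEMA helper itself (self.values), so they were silently dropped from checkpoints. Registering them as a buffer on ImagBehavior puts them in the module's state_dict, and the in-place slice assignment ema_vals[:] = ... mutates that buffer directly instead of rebinding a plain attribute. A minimal sketch of the mechanism, using a hypothetical stand-in module rather than the real ImagBehavior:

    import torch
    from torch import nn

    class Behavior(nn.Module):  # hypothetical stand-in for ImagBehavior
        def __init__(self):
            super().__init__()
            # buffers travel with state_dict() but are not trainable parameters
            self.register_buffer("ema_vals", torch.zeros(2))

    b = Behavior()
    b.ema_vals[:] = torch.tensor([0.05, 0.95])  # in-place, as RewardEMA.__call__ now does
    print("ema_vals" in b.state_dict())                  # True -> saved/loaded with the module
    print(any(p is b.ema_vals for p in b.parameters()))  # False -> never handed to the optimizer

A plain attribute assignment (self.ema_vals = ...) inside __call__ would rebind the name and leave the registered buffer stale, which is why the update is written as a slice assignment.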
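The norm option also changes type, from a string previously resolved via getattr(torch.nn, norm) to a boolean: each sub-network now either inserts nn.LayerNorm(units, eps=1e-03) after its linear layer or nothing at all, and GRUCell defaults to norm=True. A minimal sketch of the Linear -> optional LayerNorm -> activation pattern the diff repeats in RSSM and MLP (the helper name is illustrative, not the repo's API):

    from torch import nn

    def linear_block(inp_dim, units, norm=True, act=nn.SiLU):
        # mirrors the construction loop in RSSM._img_in_layers and MLP.layers
        layers = [nn.Linear(inp_dim, units, bias=False)]
        if norm:
            layers.append(nn.LayerNorm(units, eps=1e-03))
        layers.append(act())
        return nn.Sequential(*layers)

Note that an old config still carrying norm: 'LayerNorm' would keep producing LayerNorm only because a non-empty string is truthy; such configs should be updated to norm: True to match the new schema.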
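Similarly, reward_head.scale and cont_head.scale are renamed to loss_scale, and WorldModel._scales no longer enumerates every head: self._scales.get(key, 1.0) applies a default weight of 1.0 to any loss not listed, such as the decoder's image loss. A toy illustration with made-up loss values:

    # values are invented for illustration only
    losses = {"image": 2.0, "reward": 0.5, "cont": 0.25}
    scales = {"reward": 1.0, "cont": 1.0}  # 'image' intentionally absent
    scaled = {k: v * scales.get(k, 1.0) for k, v in losses.items()}
    # -> {'image': 2.0, 'reward': 0.5, 'cont': 0.25}; unlisted losses keep full weight

A consequence of this default is that any new decoder head contributes at full weight unless a loss_scale entry is added for it in the config.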