From b984e69b6e507774cc85a182b6636daa8095df4c Mon Sep 17 00:00:00 2001 From: NM512 Date: Sun, 14 May 2023 23:38:46 +0900 Subject: [PATCH] added state input capability --- README.md | 14 ++- configs.yaml | 29 +++-- dreamer.py | 39 +++++-- envs/dmc.py | 8 +- exploration.py | 2 +- models.py | 63 ++++------ networks.py | 304 ++++++++++++++++++++++++++++++++++++++----------- tools.py | 52 +++++++-- 8 files changed, 369 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index 65e9984..98f4e3b 100644 --- a/README.md +++ b/README.md @@ -7,20 +7,24 @@ Get dependencies: ``` pip install -r requirements.txt ``` -Train the agent on Walker Walk in Vision DMC: +Train the agent on Walker Walk in DMC Vision: ``` -python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk +python3 dreamer.py --configs dmc_vision --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk +``` +Train the agent on Walker Walk in DMC Proprio: +``` +python3 dreamer.py --configs dmc_proprio --task dmc_walker_walk --logdir ./logdir/dmc_walker_walk ``` Train the agent on Alien in Atari 100K: ``` -python3 dreamer.py --configs defaults atari100k --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien +python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien ``` Monitor results: ``` tensorboard --logdir ~/dreamerv3-torch/logdir ``` -## Evaluation Results +## Results More results will be added in the future. ![dmc_vision](https://user-images.githubusercontent.com/70328564/236276650-ae706f29-4c14-4ed3-9b61-1829a1fdedae.png) @@ -30,7 +34,7 @@ More results will be added in the future. - [x] Modify implementation details based on the author's implementation - [x] Evaluate on DMC vision - [x] Evaluate on Atari 100K -- [ ] Add state input capability +- [x] Add state input capability - [ ] Evaluate on DMC Proprio - [ ] etc. 
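Note on the new configs used above: dmc_vision and dmc_proprio differ only in which observation keys the encoder/decoder consume, selected by regex ('$^' matches nothing, '.*' matches everything). A minimal sketch of that selection rule, mirroring the filtering in networks.MultiEncoder further down; the observation shapes here are illustrative, not taken from the patch:

```python
import re

# Illustrative DMC observation shapes: 1D vectors feed the MLP branch,
# 3D images feed the CNN branch (see networks.MultiEncoder below).
shapes = {"orientations": (14,), "velocity": (9,), "image": (64, 64, 3)}

def split_keys(shapes, mlp_keys, cnn_keys):
    # Same convention as MultiEncoder/MultiDecoder: regex-match each key name.
    cnn = {k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)}
    mlp = {k: v for k, v in shapes.items() if len(v) in (1, 2) and re.match(mlp_keys, k)}
    return cnn, mlp

print(split_keys(shapes, mlp_keys="$^", cnn_keys="image"))  # dmc_vision: image only
print(split_keys(shapes, mlp_keys=".*", cnn_keys="$^"))     # dmc_proprio: vectors only
```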
diff --git a/configs.yaml b/configs.yaml index f4f13f3..76a270a 100644 --- a/configs.yaml +++ b/configs.yaml @@ -1,4 +1,3 @@ -# defaults is for Vision DMC defaults: logdir: null @@ -17,6 +16,7 @@ defaults: precision: 16 debug: False expl_gifs: False + video_pred_log: True # Environment task: 'dmc_walker_walk' @@ -43,7 +43,7 @@ defaults: dyn_std_act: 'sigmoid2' dyn_min_std: 0.1 dyn_temp_post: True - grad_heads: ['image', 'reward', 'cont'] + grad_heads: ['decoder', 'reward', 'cont'] units: 512 reward_layers: 2 cont_layers: 2 @@ -51,11 +51,12 @@ defaults: actor_layers: 2 act: 'SiLU' norm: 'LayerNorm' - cnn_depth: 32 - encoder_kernels: [4, 4, 4, 4] - decoder_kernels: [4, 4, 4, 4] - value_head: 'twohot_symlog' - reward_head: 'twohot_symlog' + encoder: + {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, symlog_inputs: True} + decoder: + {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, cnn_kernels: [4, 4, 4, 4], mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse,} + value_head: 'symlog_disc' + reward_head: 'symlog_disc' dyn_scale: '0.5' rep_scale: '0.1' kl_free: '1.0' @@ -119,6 +120,20 @@ defaults: disag_units: 400 disag_action_cond: False +dmc_vision: + steps: 1e6 + train_ratio: 512 + video_pred_log: true + encoder: {mlp_keys: '$^', cnn_keys: 'image'} + decoder: {mlp_keys: '$^', cnn_keys: 'image'} + +dmc_proprio: + steps: 5e5 + train_ratio: 512 + video_pred_log: false + encoder: {mlp_keys: '.*', cnn_keys: '$^'} + decoder: {mlp_keys: '.*', cnn_keys: '$^'} + atari100k: steps: 4e5 action_repeat: 4 diff --git a/dreamer.py b/dreamer.py index e4abb6f..5a23d88 100644 --- a/dreamer.py +++ b/dreamer.py @@ -27,7 +27,7 @@ to_np = lambda x: x.detach().cpu().numpy() class Dreamer(nn.Module): - def __init__(self, config, logger, dataset): + def __init__(self, obs_space, act_space, config, logger, dataset): super(Dreamer, self).__init__() self._config = config self._logger = logger @@ -51,7 +51,7 @@ class Dreamer(nn.Module): x, self._step ) self._dataset = dataset - self._wm = models.WorldModel(self._step, config) + self._wm = models.WorldModel(obs_space, act_space, self._step, config) self._task_behavior = models.ImagBehavior( config, self._wm, config.behavior_stop_grad ) @@ -90,8 +90,9 @@ class Dreamer(nn.Module): for name, values in self._metrics.items(): self._logger.scalar(name, float(np.mean(values))) self._metrics[name] = [] - openl = self._wm.video_pred(next(self._dataset)) - self._logger.video("train_openl", to_np(openl)) + if self._config.video_pred_log: + openl = self._wm.video_pred(next(self._dataset)) + self._logger.video("train_openl", to_np(openl)) self._logger.write(fps=True) policy_output, state = self._policy(obs, state, training) @@ -296,8 +297,6 @@ def main(config): config.eval_every //= config.action_repeat config.log_every //= config.action_repeat config.time_limit //= config.action_repeat - config.act = getattr(torch.nn, config.act) - config.norm = getattr(torch.nn, config.norm) print("Logdir", logdir) logdir.mkdir(parents=True, exist_ok=True) @@ -350,7 +349,13 @@ def main(config): print("Simulate agent.") train_dataset = make_dataset(train_eps, config) eval_dataset = make_dataset(eval_eps, config) - agent = Dreamer(config, logger, train_dataset).to(config.device) + agent = Dreamer( + train_envs[0].observation_space, + train_envs[0].action_space, + config, + logger, + train_dataset, + ).to(config.device) 
agent.requires_grad_(requires_grad=False) if (logdir / "latest_model.pt").exists(): agent.load_state_dict(torch.load(logdir / "latest_model.pt")) @@ -362,8 +367,9 @@ def main(config): print("Start evaluation.") eval_policy = functools.partial(agent, training=False) tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num) - video_pred = agent._wm.video_pred(next(eval_dataset)) - logger.video("eval_openl", to_np(video_pred)) + if config.video_pred_log: + video_pred = agent._wm.video_pred(next(eval_dataset)) + logger.video("eval_openl", to_np(video_pred)) print("Start training.") state = tools.simulate(agent, train_envs, config.eval_every, state=state) torch.save(agent.state_dict(), logdir / "latest_model.pt") @@ -376,14 +382,23 @@ def main(config): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--configs", nargs="+", required=True) + parser.add_argument("--configs", nargs="+") args, remaining = parser.parse_known_args() configs = yaml.safe_load( (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text() ) + + def recursive_update(base, update): + for key, value in update.items(): + if isinstance(value, dict) and key in base: + recursive_update(base[key], value) + else: + base[key] = value + + name_list = ["defaults", *args.configs] if args.configs else ["defaults"] defaults = {} - for name in args.configs: - defaults.update(configs[name]) + for name in name_list: + recursive_update(defaults, configs[name]) parser = argparse.ArgumentParser() for key, value in sorted(defaults.items(), key=lambda x: x[0]): arg_type = tools.args_type(value) diff --git a/envs/dmc.py b/envs/dmc.py index a39b29e..c88cce8 100644 --- a/envs/dmc.py +++ b/envs/dmc.py @@ -24,7 +24,11 @@ class DeepMindControl: def observation_space(self): spaces = {} for key, value in self._env.observation_spec().items(): - spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32) + if len(value.shape) == 0: + shape = (1,) + else: + shape = value.shape + spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32) spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) return gym.spaces.Dict(spaces) @@ -42,6 +46,7 @@ class DeepMindControl: if time_step.last(): break obs = dict(time_step.observation) + obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()} obs["image"] = self.render() # There is no terminal state in DMC obs["is_terminal"] = False if time_step.first() else time_step.discount == 0 @@ -53,6 +58,7 @@ class DeepMindControl: def reset(self): time_step = self._env.reset() obs = dict(time_step.observation) + obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()} obs["image"] = self.render() obs["is_terminal"] = False if time_step.first() else time_step.discount == 0 obs["is_first"] = time_step.first() diff --git a/exploration.py b/exploration.py index 40985cd..f57877c 100644 --- a/exploration.py +++ b/exploration.py @@ -52,7 +52,7 @@ class Plan2Explore(nn.Module): act=config.act, ) self._networks = nn.ModuleList( - [networks.DenseHead(**kw) for _ in range(config.disag_models)] + [networks.MLP(**kw) for _ in range(config.disag_models)] ) kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) self._model_opt = tools.Optimizer( diff --git a/models.py b/models.py index 7120234..163765c 100644 --- a/models.py +++ b/models.py @@ -29,26 +29,14 @@ class RewardEMA(object): class WorldModel(nn.Module): - def __init__(self, step, config): + def __init__(self, obs_space, 
act_space, step, config): super(WorldModel, self).__init__() self._step = step self._use_amp = True if config.precision == 16 else False self._config = config - self.encoder = networks.ConvEncoder( - config.grayscale, - config.cnn_depth, - config.act, - config.norm, - config.encoder_kernels, - ) - if config.size[0] == 64 and config.size[1] == 64: - embed_size = ( - (64 // 2 ** (len(config.encoder_kernels))) ** 2 - * config.cnn_depth - * 2 ** (len(config.encoder_kernels) - 1) - ) - else: - raise NotImplemented(f"{config.size} is not applicable now") + shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()} + self.encoder = networks.MultiEncoder(shapes, **config.encoder) + embed_size = self.encoder.outdim self.dynamics = networks.RSSM( config.dyn_stoch, config.dyn_deter, @@ -72,22 +60,15 @@ class WorldModel(nn.Module): config.device, ) self.heads = nn.ModuleDict() - channels = 1 if config.grayscale else 3 - shape = (channels,) + config.size if config.dyn_discrete: feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter else: feat_size = config.dyn_stoch + config.dyn_deter - self.heads["image"] = networks.ConvDecoder( - feat_size, # pytorch version - config.cnn_depth, - config.act, - config.norm, - shape, - config.decoder_kernels, + self.heads["decoder"] = networks.MultiDecoder( + feat_size, shapes, **config.decoder ) - if config.reward_head == "twohot_symlog": - self.heads["reward"] = networks.DenseHead( + if config.reward_head == "symlog_disc": + self.heads["reward"] = networks.MLP( feat_size, # pytorch version (255,), config.reward_layers, @@ -99,7 +80,7 @@ class WorldModel(nn.Module): device=config.device, ) else: - self.heads["reward"] = networks.DenseHead( + self.heads["reward"] = networks.MLP( feat_size, # pytorch version [], config.reward_layers, @@ -110,7 +91,7 @@ class WorldModel(nn.Module): outscale=0.0, device=config.device, ) - self.heads["cont"] = networks.DenseHead( + self.heads["cont"] = networks.MLP( feat_size, # pytorch version [], config.cont_layers, @@ -153,15 +134,19 @@ class WorldModel(nn.Module): kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss( post, prior, kl_free, dyn_scale, rep_scale ) - losses = {} - likes = {} + preds = {} for name, head in self.heads.items(): grad_head = name in self._config.grad_heads feat = self.dynamics.get_feat(post) feat = feat if grad_head else feat.detach() pred = head(feat) + if type(pred) is dict: + preds.update(pred) + else: + preds[name] = pred + losses = {} + for name, pred in preds.items(): like = pred.log_prob(data[name]) - likes[name] = like losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) model_loss = sum(losses.values()) + kl_loss metrics = self._model_opt(model_loss, self.parameters()) @@ -213,11 +198,13 @@ class WorldModel(nn.Module): states, _ = self.dynamics.observe( embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5] ) - recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6] + recon = self.heads["decoder"](self.dynamics.get_feat(states))["image"].mode()[ + :6 + ] reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6] init = {k: v[:, -1] for k, v in states.items()} prior = self.dynamics.imagine(data["action"][:6, 5:], init) - openl = self.heads["image"](self.dynamics.get_feat(prior)).mode() + openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode() reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode() # observed image is given until 5 steps model = torch.cat([recon[:, :5], openl], 1) @@ 
-254,9 +241,9 @@ class ImagBehavior(nn.Module): config.actor_temp, outscale=1.0, unimix_ratio=config.action_unimix_ratio, - ) # action_dist -> action_disc? - if config.value_head == "twohot_symlog": - self.value = networks.DenseHead( + ) + if config.value_head == "symlog_disc": + self.value = networks.MLP( feat_size, # pytorch version (255,), config.value_layers, @@ -268,7 +255,7 @@ class ImagBehavior(nn.Module): device=config.device, ) else: - self.value = networks.DenseHead( + self.value = networks.MLP( feat_size, # pytorch version [], config.value_layers, diff --git a/networks.py b/networks.py index f52d1a8..8829cde 100644 --- a/networks.py +++ b/networks.py @@ -1,5 +1,6 @@ import math import numpy as np +import re import torch from torch import nn @@ -20,8 +21,8 @@ class RSSM(nn.Module): rec_depth=1, shared=False, discrete=False, - act=nn.ELU, - norm=nn.LayerNorm, + act="SiLU", + norm="LayerNorm", mean_act="none", std_act="softplus", temp_post=True, @@ -43,8 +44,8 @@ class RSSM(nn.Module): self._rec_depth = rec_depth self._shared = shared self._discrete = discrete - self._act = act - self._norm = norm + act = getattr(torch.nn, act) + norm = getattr(torch.nn, norm) self._mean_act = mean_act self._std_act = std_act self._temp_post = temp_post @@ -62,8 +63,8 @@ class RSSM(nn.Module): inp_dim += self._embed for i in range(self._layers_input): inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - inp_layers.append(self._norm(self._hidden, eps=1e-03)) - inp_layers.append(self._act()) + inp_layers.append(norm(self._hidden, eps=1e-03)) + inp_layers.append(act()) if i == 0: inp_dim = self._hidden self._inp_layers = nn.Sequential(*inp_layers) @@ -82,8 +83,8 @@ class RSSM(nn.Module): inp_dim = self._deter for i in range(self._layers_output): img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - img_out_layers.append(self._norm(self._hidden, eps=1e-03)) - img_out_layers.append(self._act()) + img_out_layers.append(norm(self._hidden, eps=1e-03)) + img_out_layers.append(act()) if i == 0: inp_dim = self._hidden self._img_out_layers = nn.Sequential(*img_out_layers) @@ -96,8 +97,8 @@ class RSSM(nn.Module): inp_dim = self._embed for i in range(self._layers_output): obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False)) - obs_out_layers.append(self._norm(self._hidden, eps=1e-03)) - obs_out_layers.append(self._act()) + obs_out_layers.append(norm(self._hidden, eps=1e-03)) + obs_out_layers.append(act()) if i == 0: inp_dim = self._hidden self._obs_out_layers = nn.Sequential(*obs_out_layers) @@ -327,28 +328,156 @@ class RSSM(nn.Module): return loss, value, dyn_loss, rep_loss -class ConvEncoder(nn.Module): +class MultiEncoder(nn.Module): def __init__( self, - grayscale=False, - depth=32, - act=nn.ELU, - norm=nn.LayerNorm, - kernels=(3, 3, 3, 3), + shapes, + mlp_keys, + cnn_keys, + act, + norm, + cnn_depth, + cnn_kernels, + mlp_layers, + mlp_units, + symlog_inputs, + ): + super(MultiEncoder, self).__init__() + self.cnn_shapes = { + k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) + } + self.mlp_shapes = { + k: v + for k, v in shapes.items() + if len(v) in (1, 2) and re.match(mlp_keys, k) + } + print("Encoder CNN shapes:", self.cnn_shapes) + print("Encoder MLP shapes:", self.mlp_shapes) + + self.outdim = 0 + if self.cnn_shapes: + input_ch = sum([v[-1] for v in self.cnn_shapes.values()]) + self._cnn = ConvEncoder(input_ch, cnn_depth, act, norm, cnn_kernels) + self.outdim += self._cnn.outdim + if self.mlp_shapes: + input_size = sum([sum(v) for v in 
self.mlp_shapes.values()]) + self._mlp = MLP( + input_size, + None, + mlp_layers, + mlp_units, + act, + norm, + symlog_inputs=symlog_inputs, + ) + self.outdim += mlp_units + + def forward(self, obs): + outputs = [] + if self.cnn_shapes: + inputs = torch.cat([obs[k] for k in self.cnn_shapes], -1) + outputs.append(self._cnn(inputs)) + if self.mlp_shapes: + inputs = torch.cat([obs[k] for k in self.mlp_shapes], -1) + outputs.append(self._mlp(inputs)) + outputs = torch.cat(outputs, -1) + return outputs + + +class MultiDecoder(nn.Module): + def __init__( + self, + feat_size, + shapes, + mlp_keys, + cnn_keys, + act, + norm, + cnn_depth, + cnn_kernels, + mlp_layers, + mlp_units, + cnn_sigmoid, + image_dist, + vector_dist, + ): + super(MultiDecoder, self).__init__() + self.cnn_shapes = { + k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) + } + self.mlp_shapes = { + k: v + for k, v in shapes.items() + if len(v) in (1, 2) and re.match(mlp_keys, k) + } + print("Decoder CNN shapes:", self.cnn_shapes) + print("Decoder MLP shapes:", self.mlp_shapes) + + if self.cnn_shapes: + some_shape = list(self.cnn_shapes.values())[0] + shape = (sum(x[-1] for x in self.cnn_shapes.values()),) + some_shape[:-1] + self._cnn = ConvDecoder( + feat_size, + shape, + cnn_depth, + act, + norm, + cnn_kernels, + cnn_sigmoid=cnn_sigmoid, + ) + if self.mlp_shapes: + self._mlp = MLP( + feat_size, + self.mlp_shapes, + mlp_layers, + mlp_units, + act, + norm, + vector_dist, + ) + self._image_dist = image_dist + + def forward(self, features): + dists = {} + if self.cnn_shapes: + feat = features + outputs = self._cnn(feat) + split_sizes = [v[-1] for v in self.cnn_shapes.values()] + outputs = torch.split(outputs, split_sizes, -1) + dists.update( + { + key: self._make_image_dist(output) + for key, output in zip(self.cnn_shapes.keys(), outputs) + } + ) + if self.mlp_shapes: + dists.update(self._mlp(features)) + return dists + + def _make_image_dist(self, mean): + if self._image_dist == "normal": + return tools.ContDist( + torchd.independent.Independent(torchd.normal.Normal(mean, 1), 3) + ) + if self._image_dist == "mse": + return tools.MSEDist(mean) + raise NotImplementedError(self._image_dist) + + +class ConvEncoder(nn.Module): + def __init__( + self, input_ch, depth=32, act="SiLU", norm="LayerNorm", kernels=(3, 3, 3, 3) ): super(ConvEncoder, self).__init__() - self._act = act - self._norm = norm + act = getattr(torch.nn, act) + norm = getattr(torch.nn, norm) self._depth = depth self._kernels = kernels h, w = 64, 64 layers = [] for i, kernel in enumerate(self._kernels): if i == 0: - if grayscale: - inp_dim = 1 - else: - inp_dim = 3 + inp_dim = input_ch else: inp_dim = 2 ** (i - 1) * self._depth depth = 2**i * self._depth @@ -365,37 +494,42 @@ class ConvEncoder(nn.Module): layers.append(act()) h, w = h // 2, w // 2 + self.outdim = depth * h * w self.layers = nn.Sequential(*layers) self.layers.apply(tools.weight_init) - def __call__(self, obs): - x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:])) + def forward(self, obs): + # (batch, time, h, w, ch) -> (batch * time, h, w, ch) + x = obs.reshape((-1,) + tuple(obs.shape[-3:])) + # (batch * time, h, w, ch) -> (batch * time, ch, h, w) x = x.permute(0, 3, 1, 2) x = self.layers(x) - # prod: product of all elements + # (batch * time, ...) 
-> (batch * time, -1) x = x.reshape([x.shape[0], np.prod(x.shape[1:])]) - shape = list(obs["image"].shape[:-3]) + [x.shape[-1]] - return x.reshape(shape) + # (batch * time, -1) -> (batch, time, -1) + return x.reshape(list(obs.shape[:-3]) + [x.shape[-1]]) class ConvDecoder(nn.Module): def __init__( self, inp_depth, + shape=(3, 64, 64), depth=32, act=nn.ELU, norm=nn.LayerNorm, - shape=(3, 64, 64), kernels=(3, 3, 3, 3), outscale=1.0, + cnn_sigmoid=False, ): super(ConvDecoder, self).__init__() self._inp_depth = inp_depth - self._act = act - self._norm = norm + act = getattr(torch.nn, act) + norm = getattr(torch.nn, norm) self._depth = depth self._shape = shape self._kernels = kernels + self._cnn_sigmoid = cnn_sigmoid self._embed_size = ( (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1) ) @@ -407,7 +541,6 @@ class ConvDecoder(nn.Module): h, w = 4, 4 for i, kernel in enumerate(self._kernels): depth = self._embed_size // 16 // (2 ** (i + 1)) - act = self._act bias = False initializer = tools.weight_init if i == len(self._kernels) - 1: @@ -447,88 +580,125 @@ class ConvDecoder(nn.Module): outpad = pad * 2 - val return pad, outpad - def __call__(self, features, dtype=None): + def forward(self, features, dtype=None): x = self._linear_layer(features) + # (batch, time, -1) -> (batch * time, h, w, ch) x = x.reshape([-1, 4, 4, self._embed_size // 16]) + # (batch * time, h, w, ch) -> (batch * time, ch, h, w) x = x.permute(0, 3, 1, 2) x = self.layers(x) + # (batch * time, ch, h, w) -> (batch, time, ch, h, w) mean = x.reshape(features.shape[:-1] + self._shape) + # (batch, time, ch, h, w) -> (batch, time, h, w, ch) mean = mean.permute(0, 1, 3, 4, 2) - return tools.SymlogDist(mean) + if self._cnn_sigmoid: + mean = F.sigmoid(mean) - 0.5 + return mean -class DenseHead(nn.Module): +class MLP(nn.Module): def __init__( self, inp_dim, shape, layers, units, - act=nn.ELU, - norm=nn.LayerNorm, + act="SiLU", + norm="LayerNorm", dist="normal", std=1.0, outscale=1.0, + symlog_inputs=False, device="cuda", ): - super(DenseHead, self).__init__() + super(MLP, self).__init__() self._shape = (shape,) if isinstance(shape, int) else shape - if len(self._shape) == 0: + if self._shape is not None and len(self._shape) == 0: self._shape = (1,) self._layers = layers - self._units = units - self._act = act - self._norm = norm + act = getattr(torch.nn, act) + norm = getattr(torch.nn, norm) self._dist = dist self._std = std + self._symlog_inputs = symlog_inputs self._device = device layers = [] for index in range(self._layers): - layers.append(nn.Linear(inp_dim, self._units, bias=False)) - layers.append(norm(self._units, eps=1e-03)) + layers.append(nn.Linear(inp_dim, units, bias=False)) + layers.append(norm(units, eps=1e-03)) layers.append(act()) if index == 0: - inp_dim = self._units + inp_dim = units self.layers = nn.Sequential(*layers) self.layers.apply(tools.weight_init) - self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape)) - self.mean_layer.apply(tools.uniform_weight_init(outscale)) + if isinstance(self._shape, dict): + self.mean_layer = nn.ModuleDict() + for name, shape in self._shape.items(): + self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape)) + self.mean_layer.apply(tools.uniform_weight_init(outscale)) + if self._std == "learned": + self.std_layer = nn.ModuleDict() + for name, shape in self._shape.items(): + self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape)) + self.std_layer.apply(tools.uniform_weight_init(outscale)) + elif self._shape is not None: + self.mean_layer = nn.Linear(inp_dim,
np.prod(self._shape)) + self.mean_layer.apply(tools.uniform_weight_init(outscale)) + if self._std == "learned": + self.std_layer = nn.Linear(units, np.prod(self._shape)) + self.std_layer.apply(tools.uniform_weight_init(outscale)) - if self._std == "learned": - self.std_layer = nn.Linear(self._units, np.prod(self._shape)) - self.std_layer.apply(tools.uniform_weight_init(outscale)) - - def __call__(self, features, dtype=None): + def forward(self, features, dtype=None): x = features + if self._symlog_inputs: + x = tools.symlog(x) out = self.layers(x) - mean = self.mean_layer(out) - if self._std == "learned": - std = self.std_layer(out) + if self._shape is None: + return out + if isinstance(self._shape, dict): + dists = {} + for name, shape in self._shape.items(): + mean = self.mean_layer[name](out) + if self._std == "learned": + std = self.std_layer[name](out) + else: + std = self._std + dists.update({name: self.dist(self._dist, mean, std, shape)}) + return dists else: - std = self._std - if self._dist == "normal": + mean = self.mean_layer(out) + if self._std == "learned": + std = self.std_layer(out) + else: + std = self._std + return self.dist(self._dist, mean, std, self._shape) + + def dist(self, dist, mean, std, shape): + if dist == "normal": return tools.ContDist( torchd.independent.Independent( - torchd.normal.Normal(mean, std), len(self._shape) + torchd.normal.Normal(mean, std), len(shape) ) ) - if self._dist == "huber": + if dist == "huber": return tools.ContDist( torchd.independent.Independent( - tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape) + tools.UnnormalizedHuber(mean, std, 1.0), len(shape) ) ) - if self._dist == "binary": + if dist == "binary": return tools.Bernoulli( torchd.independent.Independent( - torchd.bernoulli.Bernoulli(logits=mean), len(self._shape) + torchd.bernoulli.Bernoulli(logits=mean), len(shape) ) ) - if self._dist == "twohot_symlog": - return tools.TwoHotDistSymlog(logits=mean, device=self._device) - raise NotImplementedError(self._dist) + if dist == "symlog_disc": + return tools.DiscDist(logits=mean, device=self._device) + if dist == "symlog_mse": + return tools.SymlogDist(mean) + raise NotImplementedError(dist) class ActionHead(nn.Module): @@ -553,8 +723,8 @@ class ActionHead(nn.Module): self._layers = layers self._units = units self._dist = dist - self._act = act - self._norm = norm + act = getattr(torch.nn, act) + norm = getattr(torch.nn, norm) self._min_std = min_std self._max_std = max_std self._init_std = init_std @@ -579,7 +749,7 @@ class ActionHead(nn.Module): self._dist_layer = nn.Linear(self._units, self._size) self._dist_layer.apply(tools.uniform_weight_init(outscale)) - def __call__(self, features, dtype=None): + def forward(self, features, dtype=None): x = features x = self._pre_layers(x) if self._dist == "tanh_normal": diff --git a/tools.py b/tools.py index 24e0e3d..bc46903 100644 --- a/tools.py +++ b/tools.py @@ -320,24 +320,34 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): return sample -class TwoHotDistSymlog: - def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"): +class DiscDist: + def __init__( + self, + logits, + low=-20.0, + high=20.0, + transfwd=symlog, + transbwd=symexp, + device="cuda", + ): self.logits = logits self.probs = torch.softmax(logits, -1) self.buckets = torch.linspace(low, high, steps=255).to(device) self.width = (self.buckets[-1] - self.buckets[0]) / 255 + self.transfwd = transfwd + self.transbwd = transbwd def mean(self): _mean = self.probs * self.buckets - return 
symexp(torch.sum(_mean, dim=-1, keepdim=True)) + return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True)) def mode(self): _mode = self.probs * self.buckets - return symexp(torch.sum(_mode, dim=-1, keepdim=True)) + return self.transbwd(torch.sum(_mode, dim=-1, keepdim=True)) # Inside OneHotCategorical, log_prob is calculated using only max element in targets def log_prob(self, x): - x = symlog(x) + x = self.transfwd(x) # x(time, batch, 1) below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 above = len(self.buckets) - torch.sum( @@ -366,15 +376,35 @@ class TwoHotDistSymlog: return (target * log_pred).sum(-1) +class MSEDist: + def __init__(self, mode, agg="sum"): + self._mode = mode + self._agg = agg + + def mode(self): + return self._mode + + def mean(self): + return self._mode + + def log_prob(self, value): + assert self._mode.shape == value.shape, (self._mode.shape, value.shape) + distance = (self._mode - value) ** 2 + if self._agg == "mean": + loss = distance.mean(list(range(len(distance.shape)))[2:]) + elif self._agg == "sum": + loss = distance.sum(list(range(len(distance.shape)))[2:]) + else: + raise NotImplementedError(self._agg) + return -loss + + class SymlogDist: - def __init__( - self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3] - ): + def __init__(self, mode, dist="mse", agg="sum", tol=1e-8): self._mode = mode self._dist = dist self._agg = agg self._tol = tol - self._dim_to_reduce = dim_to_reduce def mode(self): return symexp(self._mode) @@ -393,9 +423,9 @@ class SymlogDist: else: raise NotImplementedError(self._dist) if self._agg == "mean": - loss = distance.mean(self._dim_to_reduce) + loss = distance.mean(list(range(len(distance.shape)))[2:]) elif self._agg == "sum": - loss = distance.sum(self._dim_to_reduce) + loss = distance.sum(list(range(len(distance.shape)))[2:]) else: raise NotImplementedError(self._agg) return -loss
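For reference, the reward/value heads configured with 'symlog_disc' use the new tools.DiscDist above: a discrete distribution over 255 buckets in symlog space. A minimal sketch of the idea, not the repo's exact code (the bucket range matches the DiscDist defaults; the twohot helper below is illustrative): a scalar target is symlog-transformed, spread over its two nearest buckets, and decoded back with symexp.

```python
import torch

def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x):
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

# 255 buckets spanning [-20, 20] in symlog space, as in tools.DiscDist.
buckets = torch.linspace(-20.0, 20.0, steps=255)

def twohot(value):
    # Spread a scalar over its two nearest buckets (illustrative helper).
    x = symlog(torch.as_tensor(value, dtype=torch.float32))
    below = torch.clip(torch.sum((buckets <= x).int()) - 1, 0, len(buckets) - 1)
    above = torch.clip(below + 1, 0, len(buckets) - 1)
    if above == below:
        weight_below = torch.tensor(1.0)
    else:
        weight_below = (buckets[above] - x) / (buckets[above] - buckets[below])
    target = torch.zeros(len(buckets))
    target[below] = weight_below
    target[above] += 1.0 - weight_below
    return target

target = twohot(3.7)
# DiscDist.mean()/mode() decode predicted probabilities the same way:
# expectation over the buckets, then symexp back to the original scale.
print(float(symexp(torch.sum(target * buckets))))  # ~3.7
```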