commit fb5c21557a64f1c8e633e115e163bedd46b417d7 Author: NM512 Date: Sun Feb 12 22:35:25 2023 +0900 Initial Commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3cbc5d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# +*.sh +logdir* +vis_* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..65c5fe5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 NM512 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..1810177 --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +# Dreamer-v3 Pytorch +Pytorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1) + +![image_walker_walk](https://user-images.githubusercontent.com/70328564/218313056-c1158a7d-10f3-4052-b19d-6d642ee4850b.gif) + +## Instructions +Get dependencies: +``` +pip install -r requirements.txt +``` +Train the agent: +``` +python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG +``` +Monitor results: +``` +tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG +``` +## Evaluation Results +work-in-progress + +![Fig](https://user-images.githubusercontent.com/70328564/218313252-3d42193a-a7c4-4fd1-bd0a-df4f4f5787d5.png) + +## Awesome Environments used for testing: +- Deepmind control suite: https://github.com/deepmind/dm_control +- will be added soon + +## Acknowledgments +This code is heavily inspired by the following works: +- danijar's Dreamer-v2 tensorflow implementation: https://github.com/danijar/dreamerv2 +- jsikyoon's Dreamer-v2 pytorch implementation: https://github.com/jsikyoon/dreamer-torch +- RajGhugare19's Dreamer-v2 pytorch implementation: https://github.com/RajGhugare19/dreamerv2 +- denisyarats's DrQ-v2 original implementation: https://github.com/facebookresearch/drqv2 diff --git a/configs.yaml b/configs.yaml new file mode 100644 index 0000000..3ce6351 --- /dev/null +++ b/configs.yaml @@ -0,0 +1,136 @@ +defaults: + + logdir: null + traindir: null + evaldir: null + offline_traindir: '' + offline_evaldir: '' + seed: 0 + steps: 5e5 + eval_every: 1e4 + log_every: 1e4 + reset_every: 0 + #gpu_growth: True + device: 'cuda:0' + precision: 16 + debug: False + expl_gifs: False + + # Environment + task: 'dmc_walker_walk' + size: [64, 64] + envs: 1 + action_repeat: 2 + time_limit: 1000 + grayscale: False + prefill: 2500 + eval_noise: 0.0 + reward_trans: 'symlog' + obs_trans: 'normalize' + critic_trans: 'symlog' + reward_EMA: True + + # Model + dyn_cell: 'gru_layer_norm' + dyn_hidden: 512 + dyn_deter: 512 + dyn_stoch: 32 + dyn_discrete: 32 + dyn_input_layers: 1 + dyn_output_layers: 1 + dyn_rec_depth: 1 + dyn_shared: False + dyn_mean_act: 'none' + dyn_std_act: 'sigmoid2' + dyn_min_std: 0.1 + dyn_temp_post: True + grad_heads: ['image', 'reward', 'discount'] + units: 256 + reward_layers: 2 + discount_layers: 2 + value_layers: 2 + actor_layers: 2 + act: 'SiLU' + norm: 'LayerNorm' + cnn_depth: 32 + encoder_kernels: [3, 3, 3, 3] + decoder_kernels: [3, 3, 3, 3] + # changed here + value_head: 'twohot' + reward_head: 'twohot' + kl_lscale: '0.1' + kl_rscale: '0.5' + kl_free: '1.0' + kl_forward: False + pred_discount: True + discount_scale: 1.0 + reward_scale: 1.0 + weight_decay: 0.0 + unimix_ratio: 0.01 + + # Training + batch_size: 16 + batch_length: 64 + train_every: 5 + train_steps: 1 + pretrain: 100 + model_lr: 1e-4 + opt_eps: 1e-8 + grad_clip: 1000 + value_lr: 3e-5 + actor_lr: 3e-5 + ac_opt_eps: 1e-5 + value_grad_clip: 100 + actor_grad_clip: 100 + dataset_size: 0 + oversample_ends: False + slow_value_target: True + slow_actor_target: True + slow_target_update: 50 + slow_target_fraction: 0.01 + opt: 'adam' + + # Behavior. 
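+  # discount and discount_lambda parameterize the lambda-return used as the
+  # critic target, and imag_horizon sets the length of the imagined rollouts
+  # in ImagBehavior (models.py).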
+ discount: 0.997 + discount_lambda: 0.95 + imag_horizon: 15 + imag_gradient: 'dynamics' + imag_gradient_mix: '0.1' + imag_sample: True + actor_dist: 'trunc_normal' + actor_entropy: '3e-4' + actor_state_entropy: 0.0 + actor_init_std: 1.0 + actor_min_std: 0.1 + actor_disc: 5 + actor_temp: 0.1 + actor_outscale: 0.0 + expl_amount: 0.0 + eval_state_mean: False + collect_dyn_sample: True + behavior_stop_grad: True + value_decay: 0.0 + future_entropy: False + + # Exploration + expl_behavior: 'greedy' + expl_until: 0 + expl_extr_scale: 0.0 + expl_intr_scale: 1.0 + disag_target: 'stoch' + disag_log: True + disag_models: 10 + disag_offset: 1 + disag_layers: 4 + disag_units: 400 + disag_action_cond: False + +debug: + + debug: True + pretrain: 1 + prefill: 1 + train_steps: 1 + batch_size: 10 + batch_length: 20 + diff --git a/dreamer.py b/dreamer.py new file mode 100644 index 0000000..b855e64 --- /dev/null +++ b/dreamer.py @@ -0,0 +1,343 @@ +import argparse +import collections +import functools +import os +import pathlib +import sys +import warnings + +os.environ["MUJOCO_GL"] = "egl" + +import numpy as np +import ruamel.yaml as yaml + +sys.path.append(str(pathlib.Path(__file__).parent)) + +import exploration as expl +import models +import tools +import wrappers + +import torch +from torch import nn +from torch import distributions as torchd + +to_np = lambda x: x.detach().cpu().numpy() + + +class Dreamer(nn.Module): + def __init__(self, config, logger, dataset): + super(Dreamer, self).__init__() + self._config = config + self._logger = logger + self._should_log = tools.Every(config.log_every) + self._should_train = tools.Every(config.train_every) + self._should_pretrain = tools.Once() + self._should_reset = tools.Every(config.reset_every) + self._should_expl = tools.Until(int(config.expl_until / config.action_repeat)) + self._metrics = {} + self._step = count_steps(config.traindir) + # Schedules. 
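+        # actor_entropy, actor_state_entropy and imag_gradient_mix come from the
+        # config as strings or floats; tools.schedule resolves them at the current
+        # self._step, so wrapping them in lambdas keeps the values up to date as
+        # training progresses.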
+ config.actor_entropy = lambda x=config.actor_entropy: tools.schedule( + x, self._step + ) + config.actor_state_entropy = ( + lambda x=config.actor_state_entropy: tools.schedule(x, self._step) + ) + config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule( + x, self._step + ) + self._dataset = dataset + self._wm = models.WorldModel(self._step, config) + self._task_behavior = models.ImagBehavior( + config, self._wm, config.behavior_stop_grad + ) + reward = lambda f, s, a: self._wm.heads["reward"](f).mean + self._expl_behavior = dict( + greedy=lambda: self._task_behavior, + random=lambda: expl.Random(config), + plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward), + )[config.expl_behavior]() + + def __call__(self, obs, reset, state=None, reward=None, training=True): + step = self._step + if self._should_reset(step): + state = None + if state is not None and reset.any(): + mask = 1 - reset + for key in state[0].keys(): + for i in range(state[0][key].shape[0]): + state[0][key][i] *= mask[i] + for i in range(len(state[1])): + state[1][i] *= mask[i] + if training and self._should_train(step): + steps = ( + self._config.pretrain + if self._should_pretrain() + else self._config.train_steps + ) + for _ in range(steps): + self._train(next(self._dataset)) + if self._should_log(step): + for name, values in self._metrics.items(): + self._logger.scalar(name, float(np.mean(values))) + self._metrics[name] = [] + openl = self._wm.video_pred(next(self._dataset)) + self._logger.video("train_openl", to_np(openl)) + self._logger.write(fps=True) + + policy_output, state = self._policy(obs, state, training) + + if training: + self._step += len(reset) + self._logger.step = self._config.action_repeat * self._step + return policy_output, state + + def _policy(self, obs, state, training): + if state is None: + batch_size = len(obs["image"]) + latent = self._wm.dynamics.initial(len(obs["image"])) + action = torch.zeros((batch_size, self._config.num_actions)).to( + self._config.device + ) + else: + latent, action = state + embed = self._wm.encoder(self._wm.preprocess(obs)) + latent, _ = self._wm.dynamics.obs_step( + latent, action, embed, self._config.collect_dyn_sample + ) + if self._config.eval_state_mean: + latent["stoch"] = latent["mean"] + feat = self._wm.dynamics.get_feat(latent) + if not training: + actor = self._task_behavior.actor(feat) + action = actor.mode() + elif self._should_expl(self._step): + actor = self._expl_behavior.actor(feat) + action = actor.sample() + else: + actor = self._task_behavior.actor(feat) + action = actor.sample() + logprob = actor.log_prob(action) + latent = {k: v.detach() for k, v in latent.items()} + action = action.detach() + if self._config.actor_dist == "onehot_gumble": + action = torch.one_hot( + torch.argmax(action, dim=-1), self._config.num_actions + ) + action = self._exploration(action, training) + policy_output = {"action": action, "logprob": logprob} + state = (latent, action) + return policy_output, state + + def _exploration(self, action, training): + amount = self._config.expl_amount if training else self._config.eval_noise + if amount == 0: + return action + if "onehot" in self._config.actor_dist: + probs = amount / self._config.num_actions + (1 - amount) * action + return tools.OneHotDist(probs=probs).sample() + else: + return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1) + raise NotImplementedError(self._config.action_noise) + + def _train(self, data): + metrics = {} + post, context, mets = self._wm._train(data) + 
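+        # post are the posterior latent states used as starting points for
+        # imagination, context carries intermediate quantities from the
+        # world-model update (embeddings, features, KL values), and mets are
+        # its training metrics.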
metrics.update(mets) + start = post + if self._config.pred_discount: # Last step could be terminal. + start = {k: v[:, :-1] for k, v in post.items()} + context = {k: v[:, :-1] for k, v in context.items()} + reward = lambda f, s, a: self._wm.heads["reward"]( + self._wm.dynamics.get_feat(s) + ).mode() + metrics.update(self._task_behavior._train(start, reward)[-1]) + if self._config.expl_behavior != "greedy": + if self._config.pred_discount: + data = {k: v[:, :-1] for k, v in data.items()} + mets = self._expl_behavior.train(start, context, data)[-1] + metrics.update({"expl_" + key: value for key, value in mets.items()}) + for name, value in metrics.items(): + if not name in self._metrics.keys(): + self._metrics[name] = [value] + else: + self._metrics[name].append(value) + + +def count_steps(folder): + return sum(int(str(n).split("-")[-1][:-4]) - 1 for n in folder.glob("*.npz")) + + +def make_dataset(episodes, config): + generator = tools.sample_episodes( + episodes, config.batch_length, config.oversample_ends + ) + dataset = tools.from_generator(generator, config.batch_size) + return dataset + + +def make_env(config, logger, mode, train_eps, eval_eps): + suite, task = config.task.split("_", 1) + if suite == "dmc": + env = wrappers.DeepMindControl(task, config.action_repeat, config.size) + env = wrappers.NormalizeActions(env) + elif suite == "atari": + env = wrappers.Atari( + task, + config.action_repeat, + config.size, + grayscale=config.grayscale, + life_done=False and ("train" in mode), + sticky_actions=True, + all_actions=True, + ) + env = wrappers.OneHotAction(env) + elif suite == "dmlab": + env = wrappers.DeepMindLabyrinth( + task, mode if "train" in mode else "test", config.action_repeat + ) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit) + env = wrappers.SelectAction(env, key="action") + if (mode == "train") or (mode == "eval"): + callbacks = [ + functools.partial( + process_episode, config, logger, mode, train_eps, eval_eps + ) + ] + env = wrappers.CollectDataset(env, callbacks) + env = wrappers.RewardObs(env) + return env + + +def process_episode(config, logger, mode, train_eps, eval_eps, episode): + directory = dict(train=config.traindir, eval=config.evaldir)[mode] + cache = dict(train=train_eps, eval=eval_eps)[mode] + filename = tools.save_episodes(directory, [episode])[0] + length = len(episode["reward"]) - 1 + score = float(episode["reward"].astype(np.float64).sum()) + video = episode["image"] + if mode == "eval": + cache.clear() + if mode == "train" and config.dataset_size: + total = 0 + for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])): + if total <= config.dataset_size - length: + total += len(ep["reward"]) - 1 + else: + del cache[key] + logger.scalar("dataset_size", total + length) + cache[str(filename)] = episode + print(f"{mode.title()} episode has {length} steps and return {score:.1f}.") + logger.scalar(f"{mode}_return", score) + logger.scalar(f"{mode}_length", length) + logger.scalar(f"{mode}_episodes", len(cache)) + if mode == "eval" or config.expl_gifs: + logger.video(f"{mode}_policy", video[None]) + logger.write() + + +def main(config): + logdir = pathlib.Path(config.logdir).expanduser() + config.traindir = config.traindir or logdir / "train_eps" + config.evaldir = config.evaldir or logdir / "eval_eps" + config.steps //= config.action_repeat + config.eval_every //= config.action_repeat + config.log_every //= config.action_repeat + config.time_limit //= config.action_repeat 
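+    # The budgets above are specified in environment frames; dividing by
+    # action_repeat converts them to agent steps so that training, logging
+    # and evaluation intervals stay consistent with the configured values.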
+ config.act = getattr(torch.nn, config.act) + config.norm = getattr(torch.nn, config.norm) + + print("Logdir", logdir) + logdir.mkdir(parents=True, exist_ok=True) + config.traindir.mkdir(parents=True, exist_ok=True) + config.evaldir.mkdir(parents=True, exist_ok=True) + step = count_steps(config.traindir) + logger = tools.Logger(logdir, config.action_repeat * step) + + print("Create envs.") + if config.offline_traindir: + directory = config.offline_traindir.format(**vars(config)) + else: + directory = config.traindir + train_eps = tools.load_episodes(directory, limit=config.dataset_size) + if config.offline_evaldir: + directory = config.offline_evaldir.format(**vars(config)) + else: + directory = config.evaldir + eval_eps = tools.load_episodes(directory, limit=1) + make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps) + train_envs = [make("train") for _ in range(config.envs)] + eval_envs = [make("eval") for _ in range(config.envs)] + acts = train_envs[0].action_space + config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] + + if not config.offline_traindir: + prefill = max(0, config.prefill - count_steps(config.traindir)) + print(f"Prefill dataset ({prefill} steps).") + if hasattr(acts, "discrete"): + random_actor = tools.OneHotDist( + torch.zeros_like(torch.Tensor(acts.low))[None] + ) + else: + random_actor = torchd.independent.Independent( + torchd.uniform.Uniform( + torch.Tensor(acts.low)[None], torch.Tensor(acts.high)[None] + ), + 1, + ) + + def random_agent(o, d, s, r): + action = random_actor.sample() + logprob = random_actor.log_prob(action) + return {"action": action, "logprob": logprob}, None + + tools.simulate(random_agent, train_envs, prefill) + tools.simulate(random_agent, eval_envs, episodes=1) + logger.step = config.action_repeat * count_steps(config.traindir) + + print("Simulate agent.") + train_dataset = make_dataset(train_eps, config) + eval_dataset = make_dataset(eval_eps, config) + agent = Dreamer(config, logger, train_dataset).to(config.device) + agent.requires_grad_(requires_grad=False) + if (logdir / "latest_model.pt").exists(): + agent.load_state_dict(torch.load(logdir / "latest_model.pt")) + agent._should_pretrain._once = False + + state = None + while agent._step < config.steps: + logger.write() + print("Start evaluation.") + video_pred = agent._wm.video_pred(next(eval_dataset)) + logger.video("eval_openl", to_np(video_pred)) + eval_policy = functools.partial(agent, training=False) + tools.simulate(eval_policy, eval_envs, episodes=1) + print("Start training.") + state = tools.simulate(agent, train_envs, config.eval_every, state=state) + torch.save(agent.state_dict(), logdir / "latest_model.pt") + for env in train_envs + eval_envs: + try: + env.close() + except Exception: + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--configs", nargs="+", required=True) + args, remaining = parser.parse_known_args() + configs = yaml.safe_load( + (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text() + ) + defaults = {} + for name in args.configs: + defaults.update(configs[name]) + parser = argparse.ArgumentParser() + for key, value in sorted(defaults.items(), key=lambda x: x[0]): + arg_type = tools.args_type(value) + parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value)) + main(parser.parse_args(remaining)) diff --git a/exploration.py b/exploration.py new file mode 100644 index 0000000..77cbc3b --- /dev/null +++ b/exploration.py @@ -0,0 +1,108 @@ +import torch +from torch 
import nn
+from torch import distributions as torchd
+
+import models
+import networks
+import tools
+
+
+class Random(nn.Module):
+    def __init__(self, config):
+        super(Random, self).__init__()
+        self._config = config
+
+    def actor(self, feat):
+        shape = list(feat.shape[:-1]) + [self._config.num_actions]
+        if self._config.actor_dist == "onehot":
+            return tools.OneHotDist(torch.zeros(shape).to(self._config.device))
+        else:
+            ones = torch.ones(shape).to(self._config.device)
+            return tools.ContDist(torchd.uniform.Uniform(-ones, ones))
+
+    def train(self, start, context, data=None):
+        return None, {}
+
+
+class Plan2Explore(nn.Module):
+    def __init__(self, config, world_model, reward=None):
+        super(Plan2Explore, self).__init__()
+        self._config = config
+        self._use_amp = True if config.precision == 16 else False
+        self._reward = reward
+        self._behavior = models.ImagBehavior(config, world_model)
+        self.actor = self._behavior.actor
+        stoch_size = config.dyn_stoch
+        if config.dyn_discrete:
+            stoch_size *= config.dyn_discrete
+        size = {
+            "embed": 32 * config.cnn_depth,
+            "stoch": stoch_size,
+            "deter": config.dyn_deter,
+            "feat": stoch_size + config.dyn_deter,
+        }[self._config.disag_target]
+        # The ensemble heads consume model features (optionally with the action).
+        inp_dim = stoch_size + config.dyn_deter
+        if config.disag_action_cond:
+            inp_dim += config.num_actions
+        kw = dict(
+            inp_dim=inp_dim,
+            shape=size,
+            layers=config.disag_layers,
+            units=config.disag_units,
+            act=config.act,
+        )
+        self._networks = nn.ModuleList(
+            [networks.DenseHead(**kw) for _ in range(config.disag_models)]
+        )
+        self._opt = tools.Optimizer(
+            "explorer",
+            self._networks.parameters(),
+            config.model_lr,
+            config.opt_eps,
+            config.grad_clip,
+            config.weight_decay,
+            opt=config.opt,
+            use_amp=self._use_amp,
+        )
+
+    def train(self, start, context, data):
+        metrics = {}
+        stoch = start["stoch"]
+        if self._config.dyn_discrete:
+            shape = list(stoch.shape[:-2]) + [stoch.shape[-2] * stoch.shape[-1]]
+            stoch = stoch.reshape(shape)
+        target = {
+            "embed": context["embed"],
+            "stoch": stoch,
+            "deter": start["deter"],
+            "feat": context["feat"],
+        }[self._config.disag_target]
+        inputs = context["feat"]
+        if self._config.disag_action_cond:
+            inputs = torch.cat([inputs, data["action"]], -1)
+        metrics.update(self._train_ensemble(inputs, target))
+        metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
+        return None, metrics
+
+    def _intrinsic_reward(self, feat, state, action):
+        inputs = feat
+        if self._config.disag_action_cond:
+            inputs = torch.cat([inputs, action], -1)
+        preds = torch.stack([head(inputs).mode() for head in self._networks], 0)
+        # Disagreement is the std of the ensemble predictions, averaged over features.
+        disag = torch.mean(torch.std(preds, 0), -1)[..., None]
+        if self._config.disag_log:
+            disag = torch.log(disag)
+        reward = self._config.expl_intr_scale * disag
+        if self._config.expl_extr_scale:
+            reward += self._config.expl_extr_scale * self._reward(feat, state, action)
+        return reward
+
+    def _train_ensemble(self, inputs, targets):
+        with tools.RequiresGrad(self._networks):
+            with torch.cuda.amp.autocast(self._use_amp):
+                if self._config.disag_offset:
+                    targets = targets[:, self._config.disag_offset :]
+                    inputs = inputs[:, : -self._config.disag_offset]
+                targets = targets.detach()
+                inputs = inputs.detach()
+                preds = [head(inputs) for head in self._networks]
+                likes = torch.stack(
+                    [torch.mean(pred.log_prob(targets)) for pred in preds]
+                )
+                loss = -torch.sum(likes)
+            metrics = self._opt(loss, self._networks.parameters())
+        return metrics
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..7489a17
--- /dev/null
+++ b/models.py
@@ -0,0 +1,509 @@
+import copy
+import torch
+from torch import nn
+import numpy as np
+from PIL import ImageColor, Image, ImageDraw, ImageFont
+
+import networks
+import tools
+
+to_np = lambda x:
x.detach().cpu().numpy() + + +def symlog(x): + return torch.sign(x) * torch.log(torch.abs(x) + 1.0) + + +def symexp(x): + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0) + + +class RewardEMA(object): + """running mean and std""" + + def __init__(self, device, alpha=1e-2): + self.device = device + self.scale = torch.zeros((1,)).to(device) + self.alpha = alpha + self.range = torch.tensor([0.05, 0.95]).to(device) + + def __call__(self, x): + flat_x = torch.flatten(x.detach()) + x_quantile = torch.quantile(input=flat_x, q=self.range) + scale = x_quantile[1] - x_quantile[0] + new_scale = self.alpha * scale + (1 - self.alpha) * self.scale + self.scale = new_scale + return x / torch.clip(self.scale, min=1.0) + + +class WorldModel(nn.Module): + def __init__(self, step, config): + super(WorldModel, self).__init__() + self._step = step + self._use_amp = True if config.precision == 16 else False + self._config = config + self.encoder = networks.ConvEncoder( + config.grayscale, + config.cnn_depth, + config.act, + config.norm, + config.encoder_kernels, + ) + if config.size[0] == 64 and config.size[1] == 64: + embed_size = ( + (64 // 2 ** (len(config.encoder_kernels))) ** 2 + * config.cnn_depth + * 2 ** (len(config.encoder_kernels) - 1) + ) + else: + raise NotImplemented(f"{config.size} is not applicable now") + self.dynamics = networks.RSSM( + config.dyn_stoch, + config.dyn_deter, + config.dyn_hidden, + config.dyn_input_layers, + config.dyn_output_layers, + config.dyn_rec_depth, + config.dyn_shared, + config.dyn_discrete, + config.act, + config.norm, + config.dyn_mean_act, + config.dyn_std_act, + config.dyn_temp_post, + config.dyn_min_std, + config.dyn_cell, + config.unimix_ratio, + config.num_actions, + embed_size, + config.device, + ) + self.heads = nn.ModuleDict() + channels = 1 if config.grayscale else 3 + shape = (channels,) + config.size + if config.dyn_discrete: + feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter + else: + feat_size = config.dyn_stoch + config.dyn_deter + self.heads["image"] = networks.ConvDecoder( + feat_size, # pytorch version + config.cnn_depth, + config.act, + config.norm, + shape, + config.decoder_kernels, + ) + if config.reward_head == "twohot": + self.heads["reward"] = networks.DenseHead( + feat_size, # pytorch version + (255,), + config.reward_layers, + config.units, + config.act, + config.norm, + dist=config.reward_head, + ) + else: + self.heads["reward"] = networks.DenseHead( + feat_size, # pytorch version + [], + config.reward_layers, + config.units, + config.act, + config.norm, + dist=config.reward_head, + ) + # added this + self.heads["reward"].apply(tools.weight_init) + if config.pred_discount: + self.heads["discount"] = networks.DenseHead( + feat_size, # pytorch version + [], + config.discount_layers, + config.units, + config.act, + config.norm, + dist="binary", + ) + for name in config.grad_heads: + assert name in self.heads, name + self._model_opt = tools.Optimizer( + "model", + self.parameters(), + config.model_lr, + config.opt_eps, + config.grad_clip, + config.weight_decay, + opt=config.opt, + use_amp=self._use_amp, + ) + self._scales = dict(reward=config.reward_scale, discount=config.discount_scale) + + def _train(self, data): + # action (batch_size, batch_length, act_dim) + # image (batch_size, batch_length, h, w, ch) + # reward (batch_size, batch_length) + # discount (batch_size, batch_length) + data = self.preprocess(data) + + with tools.RequiresGrad(self): + with torch.cuda.amp.autocast(self._use_amp): + embed = 
self.encoder(data) + post, prior = self.dynamics.observe(embed, data["action"]) + kl_free = tools.schedule(self._config.kl_free, self._step) + kl_lscale = tools.schedule(self._config.kl_lscale, self._step) + kl_rscale = tools.schedule(self._config.kl_rscale, self._step) + kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss( + post, prior, self._config.kl_forward, kl_free, kl_lscale, kl_rscale + ) + losses = {} + likes = {} + for name, head in self.heads.items(): + grad_head = name in self._config.grad_heads + feat = self.dynamics.get_feat(post) + feat = feat if grad_head else feat.detach() + pred = head(feat) + # if name == 'image': + # losses[name] = torch.nn.functional.mse_loss(pred.mode(), data[name], 'sum') + like = pred.log_prob(data[name]) + likes[name] = like + losses[name] = -torch.mean(like) * self._scales.get(name, 1.0) + model_loss = sum(losses.values()) + kl_loss + metrics = self._model_opt(model_loss, self.parameters()) + + metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()}) + metrics["kl_free"] = kl_free + metrics["kl_lscale"] = kl_lscale + metrics["kl_rscale"] = kl_rscale + metrics["loss_lhs"] = to_np(loss_lhs) + metrics["loss_rhs"] = to_np(loss_rhs) + metrics["kl"] = to_np(torch.mean(kl_value)) + with torch.cuda.amp.autocast(self._use_amp): + metrics["prior_ent"] = to_np( + torch.mean(self.dynamics.get_dist(prior).entropy()) + ) + metrics["post_ent"] = to_np( + torch.mean(self.dynamics.get_dist(post).entropy()) + ) + context = dict( + embed=embed, + feat=self.dynamics.get_feat(post), + kl=kl_value, + postent=self.dynamics.get_dist(post).entropy(), + ) + post = {k: v.detach() for k, v in post.items()} + return post, context, metrics + + def preprocess(self, obs): + obs = obs.copy() + if self._config.obs_trans == "normalize": + obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5 + elif self._config.obs_trans == "identity": + obs["image"] = torch.Tensor(obs["image"]) + elif self._config.obs_trans == "symlog": + obs["image"] = symlog(torch.Tensor(obs["image"])) + else: + raise NotImplemented(f"{self._config.reward_trans} is not implemented") + if self._config.reward_trans == "tanh": + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["reward"] = torch.tanh(torch.Tensor(obs["reward"])).unsqueeze(-1) + elif self._config.reward_trans == "identity": + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1) + elif self._config.reward_trans == "symlog": + obs["reward"] = symlog(torch.Tensor(obs["reward"])).unsqueeze(-1) + else: + raise NotImplemented(f"{self._config.reward_trans} is not implemented") + if "discount" in obs: + obs["discount"] *= self._config.discount + # (batch_size, batch_length) -> (batch_size, batch_length, 1) + obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1) + obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()} + return obs + + def video_pred(self, data): + data = self.preprocess(data) + embed = self.encoder(data) + + states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5]) + recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6] + reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6] + init = {k: v[:, -1] for k, v in states.items()} + prior = self.dynamics.imagine(data["action"][:6, 5:], init) + openl = self.heads["image"](self.dynamics.get_feat(prior)).mode() + reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode() + # observed 
image is given until 5 steps + model = torch.cat([recon[:, :5], openl], 1) + if self._config.obs_trans == "normalize": + truth = data["image"][:6] + 0.5 + model += 0.5 + elif self._config.obs_trans == "symlog": + truth = symexp(data["image"][:6]) / 255.0 + model = symexp(model) / 255.0 + error = (model - truth + 1) / 2 + + return torch.cat([truth, model, error], 2) + + +class ImagBehavior(nn.Module): + def __init__(self, config, world_model, stop_grad_actor=True, reward=None): + super(ImagBehavior, self).__init__() + self._use_amp = True if config.precision == 16 else False + self._config = config + self._world_model = world_model + self._stop_grad_actor = stop_grad_actor + self._reward = reward + if config.dyn_discrete: + feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter + else: + feat_size = config.dyn_stoch + config.dyn_deter + self.actor = networks.ActionHead( + feat_size, # pytorch version + config.num_actions, + config.actor_layers, + config.units, + config.act, + config.norm, + config.actor_dist, + config.actor_init_std, + config.actor_min_std, + config.actor_dist, + config.actor_temp, + config.actor_outscale, + ) # action_dist -> action_disc? + if config.value_head == "twohot": + self.value = networks.DenseHead( + feat_size, # pytorch version + (255,), + config.value_layers, + config.units, + config.act, + config.norm, + config.value_head, + ) + else: + self.value = networks.DenseHead( + feat_size, # pytorch version + [], + config.value_layers, + config.units, + config.act, + config.norm, + config.value_head, + ) + self.value.apply(tools.weight_init) + if config.slow_value_target or config.slow_actor_target: + self._slow_value = copy.deepcopy(self.value) + self._updates = 0 + kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) + self._actor_opt = tools.Optimizer( + "actor", + self.actor.parameters(), + config.actor_lr, + config.ac_opt_eps, + config.actor_grad_clip, + **kw, + ) + self._value_opt = tools.Optimizer( + "value", + self.value.parameters(), + config.value_lr, + config.ac_opt_eps, + config.value_grad_clip, + **kw, + ) + if self._config.reward_EMA: + self.reward_ema = RewardEMA(device=self._config.device) + + def _train( + self, + start, + objective=None, + action=None, + reward=None, + imagine=None, + tape=None, + repeats=None, + ): + objective = objective or self._reward + self._update_slow_target() + metrics = {} + + with tools.RequiresGrad(self.actor): + with torch.cuda.amp.autocast(self._use_amp): + imag_feat, imag_state, imag_action = self._imagine( + start, self.actor, self._config.imag_horizon, repeats + ) + reward = objective(imag_feat, imag_state, imag_action) + if self._config.reward_trans == "symlog": + # rescale predicted reward by head['reward'] + reward = symexp(reward) + actor_ent = self.actor(imag_feat).entropy() + state_ent = self._world_model.dynamics.get_dist(imag_state).entropy() + # this target is not scaled + # slow is flag to indicate whether slow_target is used for lambda-return + target, weights = self._compute_target( + imag_feat, + imag_state, + imag_action, + reward, + actor_ent, + state_ent, + self._config.slow_actor_target, + ) + actor_loss, mets = self._compute_actor_loss( + imag_feat, + imag_state, + imag_action, + target, + actor_ent, + state_ent, + weights, + ) + metrics.update(mets) + if self._config.slow_value_target != self._config.slow_actor_target: + target, weights = self._compute_target( + imag_feat, + imag_state, + imag_action, + reward, + actor_ent, + state_ent, + 
self._config.slow_value_target, + ) + value_input = imag_feat + + with tools.RequiresGrad(self.value): + with torch.cuda.amp.autocast(self._use_amp): + value = self.value(value_input[:-1].detach()) + target = torch.stack(target, dim=1) + # only critic target is processed using symlog(not actor) + if self._config.critic_trans == "symlog": + metrics["unscaled_target_mean"] = to_np(torch.mean(target)) + target = symlog(target) + # (time, batch, 1), (time, batch, 1) -> (time, batch) + value_loss = -value.log_prob(target.detach()) + if self._config.value_decay: + value_loss += self._config.value_decay * value.mode() + # (time, batch, 1), (time, batch, 1) -> (1,) + value_loss = torch.mean(weights[:-1] * value_loss[:, :, None]) + + metrics["value_mean"] = to_np(torch.mean(value.mode())) + metrics["value_max"] = to_np(torch.max(value.mode())) + metrics["value_min"] = to_np(torch.min(value.mode())) + metrics["value_std"] = to_np(torch.std(value.mode())) + metrics["target_mean"] = to_np(torch.mean(target)) + metrics["reward_mean"] = to_np(torch.mean(reward)) + metrics["reward_std"] = to_np(torch.std(reward)) + metrics["actor_ent"] = to_np(torch.mean(actor_ent)) + with tools.RequiresGrad(self): + metrics.update(self._actor_opt(actor_loss, self.actor.parameters())) + metrics.update(self._value_opt(value_loss, self.value.parameters())) + return imag_feat, imag_state, imag_action, weights, metrics + + def _imagine(self, start, policy, horizon, repeats=None): + dynamics = self._world_model.dynamics + if repeats: + raise NotImplemented("repeats is not implemented in this version") + flatten = lambda x: x.reshape([-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in start.items()} + + def step(prev, _): + state, _, _ = prev + feat = dynamics.get_feat(state) + inp = feat.detach() if self._stop_grad_actor else feat + action = policy(inp).sample() + succ = dynamics.img_step(state, action, sample=self._config.imag_sample) + return succ, feat, action + + feat = 0 * dynamics.get_feat(start) + action = policy(feat).mode() + succ, feats, actions = tools.static_scan( + step, [torch.arange(horizon)], (start, feat, action) + ) + states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} + if repeats: + raise NotImplemented("repeats is not implemented in this version") + + return feats, states, actions + + def _compute_target( + self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent, slow + ): + if "discount" in self._world_model.heads: + inp = self._world_model.dynamics.get_feat(imag_state) + discount = self._world_model.heads["discount"](inp).mean + else: + discount = self._config.discount * torch.ones_like(reward) + if self._config.future_entropy and self._config.actor_entropy() > 0: + reward += self._config.actor_entropy() * actor_ent + if self._config.future_entropy and self._config.actor_state_entropy() > 0: + reward += self._config.actor_state_entropy() * state_ent + if slow: + value = self._slow_value(imag_feat).mode() + else: + value = self.value(imag_feat).mode() + if self._config.critic_trans == "symlog": + # After adding this line there is issue + value = symexp(value) + target = tools.lambda_return( + reward[:-1], + value[:-1], + discount[:-1], + bootstrap=value[-1], + lambda_=self._config.discount_lambda, + axis=0, + ) + weights = torch.cumprod( + torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0 + ).detach() + return target, weights + + def _compute_actor_loss( + self, imag_feat, imag_state, imag_action, target, actor_ent, state_ent, weights + ): 
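+        # 'dynamics' backpropagates the return target through the learned
+        # dynamics, 'reinforce' uses the score-function estimator with the
+        # critic as baseline, and 'both' mixes the two via imag_gradient_mix.
+        # With reward_EMA enabled, targets are first rescaled by running
+        # return quantiles (see RewardEMA) before the loss is computed.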
+ metrics = {} + inp = imag_feat.detach() if self._stop_grad_actor else imag_feat + policy = self.actor(inp) + actor_ent = policy.entropy() + # Q-val for actor is not transformed using symlog + target = torch.stack(target, dim=1) + if self._config.reward_EMA: + target = self.reward_ema(target) + metrics["EMA_scale"] = to_np(self.reward_ema.scale) + + if self._config.imag_gradient == "dynamics": + actor_target = target + elif self._config.imag_gradient == "reinforce": + actor_target = ( + policy.log_prob(imag_action)[:-1][:, :, None] + * (target - self.value(imag_feat[:-1]).mode()).detach() + ) + elif self._config.imag_gradient == "both": + actor_target = ( + policy.log_prob(imag_action)[:-1][:, :, None] + * (target - self.value(imag_feat[:-1]).mode()).detach() + ) + mix = self._config.imag_gradient_mix() + actor_target = mix * target + (1 - mix) * actor_target + metrics["imag_gradient_mix"] = mix + else: + raise NotImplementedError(self._config.imag_gradient) + if not self._config.future_entropy and (self._config.actor_entropy() > 0): + actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None] + actor_target += actor_entropy + metrics["actor_entropy"] = to_np(torch.mean(actor_entropy)) + if not self._config.future_entropy and (self._config.actor_state_entropy() > 0): + state_entropy = self._config.actor_state_entropy() * state_ent[:-1] + actor_target += state_entropy + metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy)) + actor_loss = -torch.mean(weights[:-1] * actor_target) + return actor_loss, metrics + + def _update_slow_target(self): + if self._config.slow_value_target or self._config.slow_actor_target: + if self._updates % self._config.slow_target_update == 0: + mix = self._config.slow_target_fraction + for s, d in zip(self.value.parameters(), self._slow_value.parameters()): + d.data = mix * s.data + (1 - mix) * d.data + self._updates += 1 diff --git a/networks.py b/networks.py new file mode 100644 index 0000000..30f6817 --- /dev/null +++ b/networks.py @@ -0,0 +1,631 @@ +import math +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +from torch import distributions as torchd + +import tools + + +class RSSM(nn.Module): + def __init__( + self, + stoch=30, + deter=200, + hidden=200, + layers_input=1, + layers_output=1, + rec_depth=1, + shared=False, + discrete=False, + act=nn.ELU, + norm=nn.LayerNorm, + mean_act="none", + std_act="softplus", + temp_post=True, + min_std=0.1, + cell="gru", + unimix_ratio=0.01, + num_actions=None, + embed=None, + device=None, + ): + super(RSSM, self).__init__() + self._stoch = stoch + self._deter = deter + self._hidden = hidden + self._min_std = min_std + self._layers_input = layers_input + self._layers_output = layers_output + self._rec_depth = rec_depth + self._shared = shared + self._discrete = discrete + self._act = act + self._norm = norm + self._mean_act = mean_act + self._std_act = std_act + self._temp_post = temp_post + self._unimix_ratio = unimix_ratio + self._embed = embed + self._device = device + + inp_layers = [] + if self._discrete: + inp_dim = self._stoch * self._discrete + num_actions + else: + inp_dim = self._stoch + num_actions + if self._shared: + inp_dim += self._embed + for i in range(self._layers_input): + inp_layers.append(nn.Linear(inp_dim, self._hidden)) + inp_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._inp_layers = nn.Sequential(*inp_layers) + + if cell == "gru": + self._cell = GRUCell(self._hidden, self._deter) + elif cell == 
"gru_layer_norm": + self._cell = GRUCell(self._hidden, self._deter, norm=True) + else: + raise NotImplementedError(cell) + + img_out_layers = [] + inp_dim = self._deter + for i in range(self._layers_output): + img_out_layers.append(nn.Linear(inp_dim, self._hidden)) + img_out_layers.append(self._norm(self._hidden)) + img_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._img_out_layers = nn.Sequential(*img_out_layers) + + obs_out_layers = [] + if self._temp_post: + inp_dim = self._deter + self._embed + else: + inp_dim = self._embed + for i in range(self._layers_output): + obs_out_layers.append(nn.Linear(inp_dim, self._hidden)) + obs_out_layers.append(self._norm(self._hidden)) + obs_out_layers.append(self._act()) + if i == 0: + inp_dim = self._hidden + self._obs_out_layers = nn.Sequential(*obs_out_layers) + + if self._discrete: + self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete) + else: + self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) + + def initial(self, batch_size): + deter = torch.zeros(batch_size, self._deter).to(self._device) + if self._discrete: + state = dict( + logit=torch.zeros([batch_size, self._stoch, self._discrete]).to( + self._device + ), + stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to( + self._device + ), + deter=deter, + ) + else: + state = dict( + mean=torch.zeros([batch_size, self._stoch]).to(self._device), + std=torch.zeros([batch_size, self._stoch]).to(self._device), + stoch=torch.zeros([batch_size, self._stoch]).to(self._device), + deter=deter, + ) + return state + + def observe(self, embed, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(action.shape[0]) + # (batch, time, ch) -> (time, batch, ch) + embed, action = swap(embed), swap(action) + post, prior = tools.static_scan( + lambda prev_state, prev_act, embed: self.obs_step( + prev_state[0], prev_act, embed + ), + (action, embed), + (state, state), + ) + + # (batch, time, stoch, discrete_num) -> (batch, time, stoch, discrete_num) + post = {k: swap(v) for k, v in post.items()} + prior = {k: swap(v) for k, v in prior.items()} + return post, prior + + def imagine(self, action, state=None): + swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) + if state is None: + state = self.initial(action.shape[0]) + assert isinstance(state, dict), state + action = action + action = swap(action) + prior = tools.static_scan(self.img_step, [action], state) + prior = prior[0] + prior = {k: swap(v) for k, v in prior.items()} + return prior + + def get_feat(self, state): + stoch = state["stoch"] + if self._discrete: + shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete] + stoch = stoch.reshape(shape) + return torch.cat([stoch, state["deter"]], -1) + + def get_dist(self, state, dtype=None): + if self._discrete: + logit = state["logit"] + dist = torchd.independent.Independent( + tools.OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1 + ) + else: + mean, std = state["mean"], state["std"] + dist = tools.ContDist( + torchd.independent.Independent(torchd.normal.Normal(mean, std), 1) + ) + return dist + + def obs_step(self, prev_state, prev_action, embed, sample=True): + # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) + # otherwise, post use different 
network(_obs_out_layers) with prior[deter] and embed as inputs + prior = self.img_step(prev_state, prev_action, None, sample) + if self._shared: + post = self.img_step(prev_state, prev_action, embed, sample) + else: + if self._temp_post: + x = torch.cat([prior["deter"], embed], -1) + else: + x = embed + # (batch_size, prior_deter + embed) -> (batch_size, hidden) + x = self._obs_out_layers(x) + # (batch_size, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("obs", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + post = {"stoch": stoch, "deter": prior["deter"], **stats} + return post, prior + + # this is used for making future image + def img_step(self, prev_state, prev_action, embed=None, sample=True): + # (batch, stoch, discrete_num) + prev_stoch = prev_state["stoch"] + if self._discrete: + shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete] + # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num) + prev_stoch = prev_stoch.reshape(shape) + if self._shared: + if embed is None: + shape = list(prev_action.shape[:-1]) + [self._embed] + embed = torch.zeros(shape) + # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed) + x = torch.cat([prev_stoch, prev_action, embed], -1) + else: + x = torch.cat([prev_stoch, prev_action], -1) + # (batch, stoch * discrete_num + action, embed) -> (batch, hidden) + x = self._inp_layers(x) + for _ in range(self._rec_depth): # rec depth is not correctly implemented + deter = prev_state["deter"] + # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter) + x, deter = self._cell(x, [deter]) + deter = deter[0] # Keras wraps the state in a list. + # (batch, deter) -> (batch, hidden) + x = self._img_out_layers(x) + # (batch, hidden) -> (batch_size, stoch, discrete_num) + stats = self._suff_stats_layer("ims", x) + if sample: + stoch = self.get_dist(stats).sample() + else: + stoch = self.get_dist(stats).mode() + prior = {"stoch": stoch, "deter": deter, **stats} + return prior + + def _suff_stats_layer(self, name, x): + if self._discrete: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete]) + return {"logit": logit} + else: + if name == "ims": + x = self._ims_stat_layer(x) + elif name == "obs": + x = self._obs_stat_layer(x) + else: + raise NotImplementedError + mean, std = torch.split(x, [self._stoch] * 2, -1) + mean = { + "none": lambda: mean, + "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0), + }[self._mean_act]() + std = { + "softplus": lambda: torch.softplus(std), + "abs": lambda: torch.abs(std + 1), + "sigmoid": lambda: torch.sigmoid(std), + "sigmoid2": lambda: 2 * torch.sigmoid(std / 2), + }[self._std_act]() + std = std + self._min_std + return {"mean": mean, "std": std} + + def kl_loss(self, post, prior, forward, free, lscale, rscale): + kld = torchd.kl.kl_divergence + dist = lambda x: self.get_dist(x) + sg = lambda x: {k: v.detach() for k, v in x.items()} + # forward == false -> (post, prior) + lhs, rhs = (prior, post) if forward else (post, prior) + + # forward == false -> Lrep + value_lhs = value = kld( + dist(lhs) if self._discrete else dist(lhs)._dist, + dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist, + ) + # forward == false -> Ldyn + value_rhs = kld( + dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist, + dist(rhs) if self._discrete else 
dist(rhs)._dist, + ) + loss_lhs = torch.clip(torch.mean(value_lhs), min=free) + loss_rhs = torch.clip(torch.mean(value_rhs), min=free) + loss = lscale * loss_lhs + rscale * loss_rhs + + return loss, value, loss_lhs, loss_rhs + + +class ConvEncoder(nn.Module): + def __init__( + self, + grayscale=False, + depth=32, + act=nn.ELU, + norm=nn.LayerNorm, + kernels=(3, 3, 3, 3), + ): + super(ConvEncoder, self).__init__() + self._act = act + self._norm = norm + self._depth = depth + self._kernels = kernels + h, w = 64, 64 + layers = [] + for i, kernel in enumerate(self._kernels): + if i == 0: + if grayscale: + inp_dim = 1 + else: + inp_dim = 3 + else: + inp_dim = 2 ** (i - 1) * self._depth + depth = 2**i * self._depth + layers.append( + Conv2dSame( + in_channels=inp_dim, + out_channels=depth, + kernel_size=(kernel, kernel), + stride=(2, 2), + ) + ) + h, w = h // 2, w // 2 + # layers.append(norm([depth, h, w])) + layers.append(act()) + self.layers = nn.Sequential(*layers) + + def __call__(self, obs): + x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:])) + x = x.permute(0, 3, 1, 2) + x = self.layers(x) + # prod: product of all elements + x = x.reshape([x.shape[0], np.prod(x.shape[1:])]) + shape = list(obs["image"].shape[:-3]) + [x.shape[-1]] + return x.reshape(shape) + + +class ConvDecoder(nn.Module): + def __init__( + self, + inp_depth, + depth=32, + act=nn.ELU, + norm=nn.LayerNorm, + shape=(3, 64, 64), + kernels=(3, 3, 3, 3), + ): + super(ConvDecoder, self).__init__() + self._inp_depth = inp_depth + self._act = act + self._norm = norm + self._depth = depth + self._shape = shape + self._kernels = kernels + self._embed_size = ( + (64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1) + ) + + self._linear_layer = nn.Linear(inp_depth, self._embed_size) + inp_dim = self._embed_size // 16 + + cnnt_layers = [] + h, w = 4, 4 + for i, kernel in enumerate(self._kernels): + depth = self._embed_size // 16 // (2 ** (i + 1)) + act = self._act + if i == len(self._kernels) - 1: + depth = self._shape[0] + act = None + if i != 0: + inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth + pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1) + pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1) + cnnt_layers.append( + nn.ConvTranspose2d( + inp_dim, + depth, + kernel, + 2, + padding=(pad_h, pad_w), + output_padding=(outpad_h, outpad_w), + ) + ) + h, w = h * 2, w * 2 + # cnnt_layers.append(norm([depth, h, w])) + if act is not None: + cnnt_layers.append(act()) + self._cnnt_layers = nn.Sequential(*cnnt_layers) + + def __call__(self, features, dtype=None): + x = self._linear_layer(features) + x = x.reshape([-1, 4, 4, self._embed_size // 16]) + x = x.permute(0, 3, 1, 2) + x = self._cnnt_layers(x) + mean = x.reshape(features.shape[:-1] + self._shape) + mean = mean.permute(0, 1, 3, 4, 2) + return tools.ContDist( + torchd.independent.Independent( + torchd.normal.Normal(mean, 1), len(self._shape) + ) + ) + + +class DenseHead(nn.Module): + def __init__( + self, + inp_dim, + shape, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="normal", + std=1.0, + unimix_ratio=0.0, + ): + super(DenseHead, self).__init__() + self._shape = (shape,) if isinstance(shape, int) else shape + if len(self._shape) == 0: + self._shape = (1,) + self._layers = layers + self._units = units + self._act = act + self._norm = norm + self._dist = dist + self._std = std + self._unimix_ratio = unimix_ratio + + mean_layers = [] + for index in range(self._layers): + mean_layers.append(nn.Linear(inp_dim, self._units)) + 
mean_layers.append(norm(self._units)) + mean_layers.append(act()) + if index == 0: + inp_dim = self._units + mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape))) + self._mean_layers = nn.Sequential(*mean_layers) + + if self._std == "learned": + self._std_layer = nn.Linear(self._units, np.prod(self._shape)) + + def __call__(self, features, dtype=None): + x = features + mean = self._mean_layers(x) + if self._std == "learned": + std = self._std_layer(x) + std = torch.softplus(std) + 0.01 + else: + std = self._std + if self._dist == "normal": + return tools.ContDist( + torchd.independent.Independent( + torchd.normal.Normal(mean, std), len(self._shape) + ) + ) + if self._dist == "huber": + return tools.ContDist( + torchd.independent.Independent( + tools.UnnormalizedHuber(mean, std, 1.0), len(self._shape) + ) + ) + if self._dist == "binary": + return tools.Bernoulli( + torchd.independent.Independent( + torchd.bernoulli.Bernoulli(logits=mean), len(self._shape) + ) + ) + if self._dist == "twohot": + return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio) + raise NotImplementedError(self._dist) + + +class ActionHead(nn.Module): + def __init__( + self, + inp_dim, + size, + layers, + units, + act=nn.ELU, + norm=nn.LayerNorm, + dist="trunc_normal", + init_std=0.0, + min_std=0.1, + action_disc=5, + temp=0.1, + outscale=0, + ): + super(ActionHead, self).__init__() + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._norm = norm + self._min_std = min_std + self._init_std = init_std + self._action_disc = action_disc + self._temp = temp() if callable(temp) else temp + self._outscale = outscale + + pre_layers = [] + for index in range(self._layers): + pre_layers.append(nn.Linear(inp_dim, self._units)) + pre_layers.append(norm(self._units)) + pre_layers.append(act()) + if index == 0: + inp_dim = self._units + self._pre_layers = nn.Sequential(*pre_layers) + + if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]: + self._dist_layer = nn.Linear(self._units, 2 * self._size) + elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]: + self._dist_layer = nn.Linear(self._units, self._size) + + def __call__(self, features, dtype=None): + x = features + x = self._pre_layers(x) + if self._dist == "tanh_normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = torch.tanh(mean) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution( + dist, tools.TanhBijector() + ) + dist = torchd.independent.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == "tanh_normal_5": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std + 5) + 5 + dist = torchd.normal.Normal(mean, std) + dist = torchd.transformed_distribution.TransformedDistribution( + dist, tools.TanhBijector() + ) + dist = torchd.independent.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == "normal": + x = self._dist_layer(x) + mean, std = torch.split(x, 2, -1) + std = F.softplus(std + self._init_std) + self._min_std + dist = torchd.normal.Normal(mean, std) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "normal_1": + x = self._dist_layer(x) + dist = torchd.normal.Normal(mean, 1) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "trunc_normal": + x = 
self._dist_layer(x) + mean, std = torch.split(x, [self._size] * 2, -1) + mean = torch.tanh(mean) + std = 2 * torch.sigmoid(std / 2) + self._min_std + dist = tools.SafeTruncatedNormal(mean, std, -1, 1) + dist = tools.ContDist(torchd.independent.Independent(dist, 1)) + elif self._dist == "onehot": + x = self._dist_layer(x) + dist = tools.OneHotDist(x) + elif self._dist == "onehot_gumble": + x = self._dist_layer(x) + temp = self._temp + dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp)) + else: + raise NotImplementedError(self._dist) + return dist + + +class GRUCell(nn.Module): + def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1): + super(GRUCell, self).__init__() + self._inp_size = inp_size + self._size = size + self._act = act + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None) + if norm: + self._norm = nn.LayerNorm(3 * size) + + @property + def state_size(self): + return self._size + + def forward(self, inputs, state): + state = state[0] # Keras wraps the state in a list. + parts = self._layer(torch.cat([inputs, state], -1)) + if self._norm: + parts = self._norm(parts) + reset, cand, update = torch.split(parts, [self._size] * 3, -1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, [output] + + +class Conv2dSame(torch.nn.Conv2d): + def calc_same_pad(self, i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x): + ih, iw = x.size()[-2:] + pad_h = self.calc_same_pad( + i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] + ) + pad_w = self.calc_same_pad( + i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + + ret = F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return ret + + +def calc_same_pad(k, s, d): + val = d * (k - 1) - s + 1 + pad = math.ceil(val / 2) + outpad = pad * 2 - val + return pad, outpad diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..30f5b3d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +torch==1.13.0 +numpy==1.20.1 +torchvision==0.14.0 +tensorboard==2.5.0 +pandas==1.2.4 +matplotlib==3.4.1 +ruamel.yaml==0.17.4 +gym[atari]==0.18.0 +moviepy==1.0.3 +einops==0.3.0 +protobuf==3.20.0 +dm_control==1.0.9 \ No newline at end of file diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..13b351e --- /dev/null +++ b/tools.py @@ -0,0 +1,700 @@ +import datetime +import io +import json +import pathlib +import pickle +import re +import time +import uuid + +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +from torch import distributions as torchd +from torch.utils.data import Dataset +from torch.utils.tensorboard import SummaryWriter + + +class RequiresGrad: + + def __init__(self, model): + self._model = model + + def __enter__(self): + self._model.requires_grad_(requires_grad=True) + + def __exit__(self, *args): + self._model.requires_grad_(requires_grad=False) + + +class TimeRecording: + + def __init__(self, comment): + self._comment = comment + + def __enter__(self): + self._st = torch.cuda.Event(enable_timing=True) + self._nd = torch.cuda.Event(enable_timing=True) + self._st.record() 
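+        # CUDA events time the GPU asynchronously; __exit__ synchronizes and
+        # reports elapsed_time() (milliseconds) divided by 1000, i.e. seconds.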
+ + def __exit__(self, *args): + self._nd.record() + torch.cuda.synchronize() + print(self._comment, self._st.elapsed_time(self._nd)/1000) + + +class Logger: + + def __init__(self, logdir, step): + self._logdir = logdir + self._writer = SummaryWriter(log_dir=str(logdir), max_queue=1000) + self._last_step = None + self._last_time = None + self._scalars = {} + self._images = {} + self._videos = {} + self.step = step + + def scalar(self, name, value): + self._scalars[name] = float(value) + + def image(self, name, value): + self._images[name] = np.array(value) + + def video(self, name, value): + self._videos[name] = np.array(value) + + def write(self, fps=False): + scalars = list(self._scalars.items()) + if fps: + scalars.append(('fps', self._compute_fps(self.step))) + print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars)) + with (self._logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n') + for name, value in scalars: + self._writer.add_scalar('scalars/' + name, value, self.step) + for name, value in self._images.items(): + self._writer.add_image(name, value, self.step) + for name, value in self._videos.items(): + name = name if isinstance(name, str) else name.decode('utf-8') + if np.issubdtype(value.dtype, np.floating): + value = np.clip(255 * value, 0, 255).astype(np.uint8) + B, T, H, W, C = value.shape + value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) + self._writer.add_video(name, value, self.step, 16) + + self._writer.flush() + self._scalars = {} + self._images = {} + self._videos = {} + + def _compute_fps(self, step): + if self._last_step is None: + self._last_time = time.time() + self._last_step = step + return 0 + steps = step - self._last_step + duration = time.time() - self._last_time + self._last_time += duration + self._last_step = step + return steps / duration + + def offline_scalar(self, name, value, step): + self._writer.add_scalar('scalars/'+name, value, step) + + def offline_video(self, name, value, step): + if np.issubdtype(value.dtype, np.floating): + value = np.clip(255 * value, 0, 255).astype(np.uint8) + B, T, H, W, C = value.shape + value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W)) + self._writer.add_video(name, value, step, 16) + + +def simulate(agent, envs, steps=0, episodes=0, state=None): + # Initialize or unpack simulation state. + if state is None: + step, episode = 0, 0 + done = np.ones(len(envs), np.bool) + length = np.zeros(len(envs), np.int32) + obs = [None] * len(envs) + agent_state = None + reward = [0]*len(envs) + else: + step, episode, done, length, obs, agent_state, reward = state + while (steps and step < steps) or (episodes and episode < episodes): + # Reset envs if necessary. + if done.any(): + indices = [index for index, d in enumerate(done) if d] + results = [envs[i].reset() for i in indices] + for index, result in zip(indices, results): + obs[index] = result + reward = [reward[i]*(1-done[i]) for i in range(len(envs))] + # Step agents. + obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} + action, agent_state = agent(obs, done, agent_state, reward) + if isinstance(action, dict): + action = [ + {k: np.array(action[k][i].detach().cpu()) for k in action} + for i in range(len(envs))] + else: + action = np.array(action) + assert len(action) == len(envs) + # Step envs. 
+ results = [e.step(a) for e, a in zip(envs, action)] + obs, reward, done = zip(*[p[:3] for p in results]) + obs = list(obs) + reward = list(reward) + done = np.stack(done) + episode += int(done.sum()) + length += 1 + step += (done * length).sum() + length *= (1 - done) + + return (step - steps, episode - episodes, done, length, obs, agent_state, reward) + + +def save_episodes(directory, episodes): + directory = pathlib.Path(directory).expanduser() + directory.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + filenames = [] + for episode in episodes: + identifier = str(uuid.uuid4().hex) + length = len(episode['reward']) + filename = directory / f'{timestamp}-{identifier}-{length}.npz' + with io.BytesIO() as f1: + np.savez_compressed(f1, **episode) + f1.seek(0) + with filename.open('wb') as f2: + f2.write(f1.read()) + filenames.append(filename) + return filenames + + +def from_generator(generator, batch_size): + while True: + batch = [] + for _ in range(batch_size): + batch.append(next(generator)) + data = {} + for key in batch[0].keys(): + data[key] = [] + for i in range(batch_size): + data[key].append(batch[i][key]) + data[key] = np.stack(data[key], 0) + yield data + + +def sample_episodes(episodes, length=None, balance=False, seed=0): + random = np.random.RandomState(seed) + while True: + episode = random.choice(list(episodes.values())) + if length: + total = len(next(iter(episode.values()))) + available = total - length + if available < 1: + print(f'Skipped short episode of length {available}.') + continue + if balance: + index = min(random.randint(0, total), available) + else: + index = int(random.randint(0, available + 1)) + episode = {k: v[index: index + length] for k, v in episode.items()} + yield episode + + +def load_episodes(directory, limit=None, reverse=True): + directory = pathlib.Path(directory).expanduser() + episodes = {} + total = 0 + if reverse: + for filename in reversed(sorted(directory.glob('*.npz'))): + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + episodes[str(filename)] = episode + total += len(episode['reward']) - 1 + if limit and total >= limit: + break + else: + for filename in sorted(directory.glob('*.npz')): + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + episodes[str(filename)] = episode + total += len(episode['reward']) - 1 + if limit and total >= limit: + break + return episodes + + +class SampleDist: + + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return torch.mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return sample[torch.argmax(logprob)][0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -torch.mean(logprob, 0) + + +class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0): + if logits is not None and probs is None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = 
probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = None + super().__init__(logits=logits, probs=probs) + + def mode(self): + _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) + return _mode.detach() + super().logits - super().logits.detach() + + def sample(self, sample_shape=(), seed=None): + if seed is not None: + raise ValueError('need to check') + sample = super().sample(sample_shape) + probs = super().probs + while len(probs.shape) < len(sample.shape): + probs = probs[None] + sample += probs - probs.detach() + return sample + + +class TwoHotDist(torchd.one_hot_categorical.OneHotCategorical): + + def __init__(self, logits=None, probs=None, unimix_ratio=0.0, device='cuda'): + if logits is not None and probs is None and unimix_ratio > 0.0: + probs = F.softmax(logits, dim=-1) + probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1] + logits = None + super().__init__(logits=logits, probs=probs) + + self.buckets = torch.linspace(-20.0, 20.0, steps=255).to(device) + self.width = (self.buckets[-1] - self.buckets[0]) / 255 + + def mode(self): + _mode = super().probs * self.buckets + return torch.sum(_mode, dim=-1, keepdim=True) + + # Inside OneHotCategorical, log_prob is calculated using only max element in targets + def log_prob(self, x): + # x(time, batch, 1) + x = (x - self.buckets[0]) / self.width + lower_indices = (x).to(torch.int64) + # lower_indices is idnside 0 ~ len(buckets)-2 + lower_indices = torch.clip(lower_indices, max=len(self.buckets)-2) + # upper_indices is inside 1 ~ len(buckets)-1 + upper_indices = lower_indices + 1 + lower_weight = torch.abs(x - upper_indices).squeeze(-1) + upper_weight = torch.abs(x - lower_indices).squeeze(-1) + # (time, batch, 1) -> (time, batch, bucket_class) + lower_log_prob = super().log_prob(F.one_hot(lower_indices.squeeze(-1), num_classes=len(self.buckets))) + upper_log_prob = super().log_prob(F.one_hot(upper_indices.squeeze(-1), num_classes=len(self.buckets))) + + # label = lower_log_prob * lower_weight + upper_log_prob * upper_weight + # # (time, batch, bucket_class) -> (time, batch) + # cross_entropy = torch.sum(torch.log(super().probs) * label, axis=-1) + + return lower_weight * lower_log_prob + upper_weight * upper_log_prob + +class ContDist: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + return self._dist.mean + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + return self._dist.log_prob(x) + + +class Bernoulli: + + def __init__(self, dist=None): + super().__init__() + self._dist = dist + self.mean = dist.mean + + def __getattr__(self, name): + return getattr(self._dist, name) + + def entropy(self): + return self._dist.entropy() + + def mode(self): + _mode = torch.round(self._dist.mean) + return _mode.detach() +self._dist.mean - self._dist.mean.detach() + + def sample(self, sample_shape=()): + return self._dist.rsample(sample_shape) + + def log_prob(self, x): + _logits = self._dist.base_dist.logits + log_probs0 = -F.softplus(_logits) + log_probs1 = -F.softplus(-_logits) + + return log_probs0 * (1-x) + log_probs1 * x + + +class UnnormalizedHuber(torchd.normal.Normal): + + def __init__(self, loc, scale, threshold=1, **kwargs): + super().__init__(loc, scale, **kwargs) + self._threshold = threshold + + def log_prob(self, event): + 
return -(torch.sqrt( + (event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) + + def mode(self): + return self.mean + + +class SafeTruncatedNormal(torchd.normal.Normal): + + def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): + super().__init__(loc, scale) + self._low = low + self._high = high + self._clip = clip + self._mult = mult + + def sample(self, sample_shape): + event = super().sample(sample_shape) + if self._clip: + clipped = torch.clip(event, self._low + self._clip, + self._high - self._clip) + event = event - event.detach() + clipped.detach() + if self._mult: + event *= self._mult + return event + + +class TanhBijector(torchd.Transform): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__() + + def _forward(self, x): + return torch.tanh(x) + + def _inverse(self, y): + y = torch.where( + (torch.abs(y) <= 1.), + torch.clamp(y, -0.99999997, 0.99999997), y) + y = torch.atanh(y) + return y + + def _forward_log_det_jacobian(self, x): + log2 = torch.math.log(2.0) + return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) + + +def static_scan_for_lambda_return(fn, inputs, start): + last = start + indices = range(inputs[0].shape[0]) + indices = reversed(indices) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) + last = fn(last, *inp(index)) + if flag: + outputs = last + flag = False + else: + outputs = torch.cat([outputs, last], dim=-1) + outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1]) + outputs = torch.unbind(outputs, dim=0) + return outputs + + +def lambda_return( + reward, value, pcont, bootstrap, lambda_, axis): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. + #assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) + assert len(reward.shape) == len(value.shape), (reward.shape, value.shape) + if isinstance(pcont, (int, float)): + pcont = pcont * torch.ones_like(reward) + dims = list(range(len(reward.shape))) + dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] + if axis != 0: + reward = reward.permute(dims) + value = value.permute(dims) + pcont = pcont.permute(dims) + if bootstrap is None: + bootstrap = torch.zeros_like(value[-1]) + next_values = torch.cat([value[1:], bootstrap[None]], 0) + inputs = reward + pcont * next_values * (1 - lambda_) + #returns = static_scan( + # lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, + # (inputs, pcont), bootstrap, reverse=True) + # reimplement to optimize performance + returns = static_scan_for_lambda_return( + lambda agg, cur0, cur1: cur0 + cur1 * lambda_ * agg, + (inputs, pcont), bootstrap) + if axis != 0: + returns = returns.permute(dims) + return returns + + +class Optimizer(): + + def __init__( + self, name, parameters, lr, eps=1e-4, clip=None, wd=None, wd_pattern=r'.*', + opt='adam', use_amp=False): + assert 0 <= wd < 1 + assert not clip or 1 <= clip + self._name = name + self._parameters = parameters + self._clip = clip + self._wd = wd + self._wd_pattern = wd_pattern + self._opt = { + 'adam': lambda: torch.optim.Adam(parameters, + lr=lr, + eps=eps), + 'nadam': lambda: NotImplemented( + f'{config.opt} is not implemented'), + 'adamax': lambda: torch.optim.Adamax(parameters, + lr=lr, + eps=eps), + 'sgd': lambda: torch.optim.SGD(parameters, + lr=lr), + 'momentum': lambda: torch.optim.SGD(parameters, + lr=lr, + momentum=0.9), + }[opt]() + self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp) + + def __call__(self, loss, params, 
retain_graph=False): + assert len(loss.shape) == 0, loss.shape + metrics = {} + metrics[f'{self._name}_loss'] = loss.detach().cpu().numpy() + self._scaler.scale(loss).backward() + self._scaler.unscale_(self._opt) + #loss.backward(retain_graph=retain_graph) + norm = torch.nn.utils.clip_grad_norm_(params, self._clip) + if self._wd: + self._apply_weight_decay(params) + self._scaler.step(self._opt) + self._scaler.update() + #self._opt.step() + self._opt.zero_grad() + metrics[f'{self._name}_grad_norm'] = norm.item() + return metrics + + def _apply_weight_decay(self, varibs): + nontrivial = (self._wd_pattern != r'.*') + if nontrivial: + raise NotImplementedError + for var in varibs: + var.data = (1 - self._wd) * var.data + + +def args_type(default): + def parse_string(x): + if default is None: + return x + if isinstance(default, bool): + return bool(['False', 'True'].index(x)) + if isinstance(default, int): + return float(x) if ('e' in x or '.' in x) else int(x) + if isinstance(default, (list, tuple)): + return tuple(args_type(default[0])(y) for y in x.split(',')) + return type(default)(x) + def parse_object(x): + if isinstance(default, (list, tuple)): + return tuple(x) + return x + return lambda x: parse_string(x) if isinstance(x, str) else parse_object(x) + + +def static_scan(fn, inputs, start): + last = start + indices = range(inputs[0].shape[0]) + flag = True + for index in indices: + inp = lambda x: (_input[x] for _input in inputs) + last = fn(last, *inp(index)) + if flag: + if type(last) == type({}): + outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()} + else: + outputs = [] + for _last in last: + if type(_last) == type({}): + outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()}) + else: + outputs.append(_last.clone().unsqueeze(0)) + flag = False + else: + if type(last) == type({}): + for key in last.keys(): + outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0) + else: + for j in range(len(outputs)): + if type(last[j]) == type({}): + for key in last[j].keys(): + outputs[j][key] = torch.cat([outputs[j][key], + last[j][key].unsqueeze(0)], dim=0) + else: + outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0) + if type(last) == type({}): + outputs = [outputs] + return outputs + + +# Original version +#def static_scan2(fn, inputs, start, reverse=False): +# last = start +# outputs = [[] for _ in range(len([start] if type(start)==type({}) else start))] +# indices = range(inputs[0].shape[0]) +# if reverse: +# indices = reversed(indices) +# for index in indices: +# inp = lambda x: (_input[x] for _input in inputs) +# last = fn(last, *inp(index)) +# [o.append(l) for o, l in zip(outputs, [last] if type(last)==type({}) else last)] +# if reverse: +# outputs = [list(reversed(x)) for x in outputs] +# res = [[]] * len(outputs) +# for i in range(len(outputs)): +# if type(outputs[i][0]) == type({}): +# _res = {} +# for key in outputs[i][0].keys(): +# _res[key] = [] +# for j in range(len(outputs[i])): +# _res[key].append(outputs[i][j][key]) +# #_res[key] = torch.stack(_res[key], 0) +# _res[key] = faster_stack(_res[key], 0) +# else: +# _res = outputs[i] +# #_res = torch.stack(_res, 0) +# _res = faster_stack(_res, 0) +# res[i] = _res +# return res + + +class Every: + + def __init__(self, every): + self._every = every + self._last = None + + def __call__(self, step): + if not self._every: + return False + if self._last is None: + self._last = step + return True + if step >= self._last + self._every: + self._last += self._every + 
return True + return False + + +class Once: + + def __init__(self): + self._once = True + + def __call__(self): + if self._once: + self._once = False + return True + return False + + +class Until: + + def __init__(self, until): + self._until = until + + def __call__(self, step): + if not self._until: + return True + return step < self._until + + +def schedule(string, step): + try: + return float(string) + except ValueError: + match = re.match(r'linear\((.+),(.+),(.+)\)', string) + if match: + initial, final, duration = [float(group) for group in match.groups()] + mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0] + return (1 - mix) * initial + mix * final + match = re.match(r'warmup\((.+),(.+)\)', string) + if match: + warmup, value = [float(group) for group in match.groups()] + scale = torch.clip(step / warmup, 0, 1) + return scale * value + match = re.match(r'exp\((.+),(.+),(.+)\)', string) + if match: + initial, final, halflife = [float(group) for group in match.groups()] + return (initial - final) * 0.5 ** (step / halflife) + final + match = re.match(r'horizon\((.+),(.+),(.+)\)', string) + if match: + initial, final, duration = [float(group) for group in match.groups()] + mix = torch.clip(step / duration, 0, 1) + horizon = (1 - mix) * initial + mix * final + return 1 - 1 / horizon + raise NotImplementedError(string) + +def weight_init(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + gain = nn.init.calculate_gain('relu') + nn.init.orthogonal_(m.weight.data, gain) + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.LayerNorm): + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) \ No newline at end of file diff --git a/wrappers.py b/wrappers.py new file mode 100644 index 0000000..38448cc --- /dev/null +++ b/wrappers.py @@ -0,0 +1,419 @@ +import threading + +import gym +import numpy as np + + +class DeepMindLabyrinth(object): + ACTION_SET_DEFAULT = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward + (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward + (0, 0, 0, 0, 1, 0, 0), # Fire + ) + + ACTION_SET_MEDIUM = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (0, 0, 0, -1, 0, 0, 0), # Backward + (0, 0, -1, 0, 0, 0, 0), # Strafe Left + (0, 0, 1, 0, 0, 0, 0), # Strafe Right + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + (0, 0, 0, 0, 0, 0, 0), # Idle. 
+ ) + + ACTION_SET_SMALL = ( + (0, 0, 0, 1, 0, 0, 0), # Forward + (-20, 0, 0, 0, 0, 0, 0), # Look Left + (20, 0, 0, 0, 0, 0, 0), # Look Right + ) + + def __init__( + self, + level, + mode, + action_repeat=4, + render_size=(64, 64), + action_set=ACTION_SET_DEFAULT, + level_cache=None, + seed=None, + runfiles_path=None, + ): + assert mode in ("train", "test") + import deepmind_lab + + if runfiles_path: + print("Setting DMLab runfiles path:", runfiles_path) + deepmind_lab.set_runfiles_path(runfiles_path) + self._config = {} + self._config["width"] = render_size[0] + self._config["height"] = render_size[1] + self._config["logLevel"] = "WARN" + if mode == "test": + self._config["allowHoldOutLevels"] = "true" + self._config["mixerSeed"] = 0x600D5EED + self._action_repeat = action_repeat + self._random = np.random.RandomState(seed) + self._env = deepmind_lab.Lab( + level="contributed/dmlab30/" + level, + observations=["RGB_INTERLEAVED"], + config={k: str(v) for k, v in self._config.items()}, + level_cache=level_cache, + ) + self._action_set = action_set + self._last_image = None + self._done = True + + @property + def observation_space(self): + shape = (self._config["height"], self._config["width"], 3) + space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) + return gym.spaces.Dict({"image": space}) + + @property + def action_space(self): + return gym.spaces.Discrete(len(self._action_set)) + + def reset(self): + self._done = False + self._env.reset(seed=self._random.randint(0, 2**31 - 1)) + obs = self._get_obs() + return obs + + def step(self, action): + raw_action = np.array(self._action_set[action], np.intc) + reward = self._env.step(raw_action, num_steps=self._action_repeat) + self._done = not self._env.is_running() + obs = self._get_obs() + return obs, reward, self._done, {} + + def render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + del args # Unused + del kwargs # Unused + return self._last_image + + def close(self): + self._env.close() + + def _get_obs(self): + if self._done: + image = 0 * self._last_image + else: + image = self._env.observations()["RGB_INTERLEAVED"] + self._last_image = image + return {"image": image} + + +class DeepMindControl: + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): + domain, task = name.split("_", 1) + if domain == "cup": # Only domain with multiple words. 
+ domain = "ball_in_cup" + if isinstance(domain, str): + from dm_control import suite + + self._env = suite.load(domain, task) + else: + assert task is None + self._env = domain() + self._action_repeat = action_repeat + self._size = size + if camera is None: + camera = dict(quadruped=2).get(domain, 0) + self._camera = camera + + @property + def observation_space(self): + spaces = {} + for key, value in self._env.observation_spec().items(): + spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32) + spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8) + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + spec = self._env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + + def step(self, action): + assert np.isfinite(action).all(), action + reward = 0 + for _ in range(self._action_repeat): + time_step = self._env.step(action) + reward += time_step.reward or 0 + if time_step.last(): + break + obs = dict(time_step.observation) + obs["image"] = self.render() + done = time_step.last() + info = {"discount": np.array(time_step.discount, np.float32)} + return obs, reward, done, info + + def reset(self): + time_step = self._env.reset() + obs = dict(time_step.observation) + obs["image"] = self.render() + return obs + + def render(self, *args, **kwargs): + if kwargs.get("mode", "rgb_array") != "rgb_array": + raise ValueError("Only render mode 'rgb_array' is supported.") + return self._env.physics.render(*self._size, camera_id=self._camera) + + +class Atari: + LOCK = threading.Lock() + + def __init__( + self, + name, + action_repeat=4, + size=(84, 84), + grayscale=True, + noops=30, + life_done=False, + sticky_actions=True, + all_actions=False, + ): + assert size[0] == size[1] + import gym.wrappers + import gym.envs.atari + + if name == "james_bond": + name = "jamesbond" + with self.LOCK: + env = gym.envs.atari.AtariEnv( + game=name, + obs_type="image", + frameskip=1, + repeat_action_probability=0.25 if sticky_actions else 0.0, + full_action_space=all_actions, + ) + # Avoid unnecessary rendering in inner env. + env._get_obs = lambda: None + # Tell wrapper that the inner env has no action repeat. 
+ env.spec = gym.envs.registration.EnvSpec("NoFrameskip-v0") + env = gym.wrappers.AtariPreprocessing( + env, noops, action_repeat, size[0], life_done, grayscale + ) + self._env = env + self._grayscale = grayscale + + @property + def observation_space(self): + return gym.spaces.Dict( + { + "image": self._env.observation_space, + "ram": gym.spaces.Box(0, 255, (128,), np.uint8), + } + ) + + @property + def action_space(self): + return self._env.action_space + + def close(self): + return self._env.close() + + def reset(self): + with self.LOCK: + image = self._env.reset() + if self._grayscale: + image = image[..., None] + obs = {"image": image, "ram": self._env.env._get_ram()} + return obs + + def step(self, action): + image, reward, done, info = self._env.step(action) + if self._grayscale: + image = image[..., None] + obs = {"image": image, "ram": self._env.env._get_ram()} + return obs, reward, done, info + + def render(self, mode): + return self._env.render(mode) + + +class CollectDataset: + def __init__(self, env, callbacks=None, precision=32): + self._env = env + self._callbacks = callbacks or () + self._precision = precision + self._episode = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = {k: self._convert(v) for k, v in obs.items()} + transition = obs.copy() + if isinstance(action, dict): + transition.update(action) + else: + transition["action"] = action + transition["reward"] = reward + transition["discount"] = info.get("discount", np.array(1 - float(done))) + self._episode.append(transition) + if done: + for key, value in self._episode[1].items(): + if key not in self._episode[0]: + self._episode[0][key] = 0 * value + episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} + episode = {k: self._convert(v) for k, v in episode.items()} + info["episode"] = episode + for callback in self._callbacks: + callback(episode) + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + transition = obs.copy() + # Missing keys will be filled with a zeroed out version of the first + # transition, because we do not know what action information the agent will + # pass yet. + transition["reward"] = 0.0 + transition["discount"] = 1.0 + self._episode = [transition] + return obs + + def _convert(self, value): + value = np.array(value) + if np.issubdtype(value.dtype, np.floating): + dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] + elif np.issubdtype(value.dtype, np.signedinteger): + dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] + elif np.issubdtype(value.dtype, np.uint8): + dtype = np.uint8 + else: + raise NotImplementedError(value.dtype) + return value.astype(dtype) + + +class TimeLimit: + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + assert self._step is not None, "Must reset environment." 
+ obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if "discount" not in info: + info["discount"] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info + + def reset(self): + self._step = 0 + return self._env.reset() + + +class NormalizeActions: + def __init__(self, env): + self._env = env + self._mask = np.logical_and( + np.isfinite(env.action_space.low), np.isfinite(env.action_space.high) + ) + self._low = np.where(self._mask, env.action_space.low, -1) + self._high = np.where(self._mask, env.action_space.high, 1) + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + low = np.where(self._mask, -np.ones_like(self._low), self._low) + high = np.where(self._mask, np.ones_like(self._low), self._high) + return gym.spaces.Box(low, high, dtype=np.float32) + + def step(self, action): + original = (action + 1) / 2 * (self._high - self._low) + self._low + original = np.where(self._mask, original, action) + return self._env.step(original) + + +class OneHotAction: + def __init__(self, env): + assert isinstance(env.action_space, gym.spaces.Discrete) + self._env = env + self._random = np.random.RandomState() + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def action_space(self): + shape = (self._env.action_space.n,) + space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) + space.sample = self._sample_action + space.discrete = True + return space + + def step(self, action): + index = np.argmax(action).astype(int) + reference = np.zeros_like(action) + reference[index] = 1 + if not np.allclose(reference, action): + raise ValueError(f"Invalid one-hot action:\n{action}") + return self._env.step(index) + + def reset(self): + return self._env.reset() + + def _sample_action(self): + actions = self._env.action_space.n + index = self._random.randint(0, actions) + reference = np.zeros(actions, dtype=np.float32) + reference[index] = 1.0 + return reference + + +class RewardObs: + def __init__(self, env): + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) + + @property + def observation_space(self): + spaces = self._env.observation_space.spaces + assert "reward" not in spaces + spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + return gym.spaces.Dict(spaces) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs["reward"] = reward + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + obs["reward"] = 0.0 + return obs + + +class SelectAction: + def __init__(self, env, key): + self._env = env + self._key = key + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + return self._env.step(action[self._key])
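
The wrapper classes added in wrappers.py above are designed to be stacked around a base environment. As a minimal sketch (not part of this commit, and the exact wiring in dreamer.py may differ), one plausible composition around a DeepMind Control task looks like this; the task name, durations, and callback are illustrative assumptions:

```
# Illustrative sketch only: chaining the wrappers defined in wrappers.py.
# Names and arguments follow the constructors above; ordering is an assumption.
import wrappers

env = wrappers.DeepMindControl("walker_walk", action_repeat=2, size=(64, 64))
env = wrappers.NormalizeActions(env)          # rescale bounded actions to [-1, 1]
env = wrappers.TimeLimit(env, duration=1000)  # end episodes after a fixed number of steps
env = wrappers.RewardObs(env)                 # expose the reward as part of the observation dict
env = wrappers.CollectDataset(env, callbacks=None)  # buffer transitions into an episode dict

obs = env.reset()
done = False
while not done:
    # Random policy just to exercise the wrapper stack.
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
```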