# dreamerv3-torch/dreamer.py

import argparse
import functools
import os
import pathlib
import sys
os.environ["MUJOCO_GL"] = "osmesa"
import numpy as np
import ruamel.yaml as yaml
sys.path.append(str(pathlib.Path(__file__).parent))
import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy
import torch
from torch import nn
from torch import distributions as torchd
to_np = lambda x: x.detach().cpu().numpy()
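

# Dreamer agent: wraps the learned world model, the imagination-based
# actor-critic (task behavior), and an optional exploration behavior behind
# the policy interface used by tools.simulate.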
class Dreamer(nn.Module):
    def __init__(self, obs_space, act_space, config, logger, dataset):
        super(Dreamer, self).__init__()
        self._config = config
        self._logger = logger
        self._should_log = tools.Every(config.log_every)
        batch_steps = config.batch_size * config.batch_length
        self._should_train = tools.Every(batch_steps / config.train_ratio)
        self._should_pretrain = tools.Once()
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
        self._metrics = {}
        # agent step count, i.e. environment steps // action_repeat
        self._step = logger.step // config.action_repeat
        self._update_count = 0
        # Schedules.
        config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
            x, self._step
        )
        config.actor_state_entropy = (
            lambda x=config.actor_state_entropy: tools.schedule(x, self._step)
        )
        config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule(
            x, self._step
        )
        self._dataset = dataset
        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
        self._task_behavior = models.ImagBehavior(
            config, self._wm, config.behavior_stop_grad
        )
        if (
            config.compile and os.name != "nt"
        ):  # compilation is not supported on windows
            self._wm = torch.compile(self._wm)
            self._task_behavior = torch.compile(self._task_behavior)
        reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
        self._expl_behavior = dict(
            greedy=lambda: self._task_behavior,
            random=lambda: expl.Random(config, act_space),
            plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
        )[config.expl_behavior]().to(self._config.device)
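
    # One interaction step from tools.simulate: optionally run training
    # updates and logging, then compute the policy output for the current
    # observations and carry the recurrent state forward.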
    def __call__(self, obs, reset, state=None, training=True):
        step = self._step
        if self._should_reset(step):
            state = None
        if state is not None and reset.any():
            mask = 1 - reset
            for key in state[0].keys():
                for i in range(state[0][key].shape[0]):
                    state[0][key][i] *= mask[i]
            for i in range(len(state[1])):
                state[1][i] *= mask[i]
        if training:
            steps = (
                self._config.pretrain
                if self._should_pretrain()
                else self._should_train(step)
            )
            for _ in range(steps):
                self._train(next(self._dataset))
                self._update_count += 1
                self._metrics["update_count"] = self._update_count
            if self._should_log(step):
                for name, values in self._metrics.items():
                    self._logger.scalar(name, float(np.mean(values)))
                    self._metrics[name] = []
                if self._config.video_pred_log:
                    openl = self._wm.video_pred(next(self._dataset))
                    self._logger.video("train_openl", to_np(openl))
                self._logger.write(fps=True)
        policy_output, state = self._policy(obs, state, training)
        if training:
            self._step += len(reset)
            self._logger.step = self._config.action_repeat * self._step
        return policy_output, state
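
    # Select actions for a batch of observations: encode, update the RSSM
    # posterior, then query the task actor (mode during eval, samples during
    # training) or the exploration actor.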
    def _policy(self, obs, state, training):
        if state is None:
            batch_size = len(obs["image"])
            latent = self._wm.dynamics.initial(len(obs["image"]))
            action = torch.zeros((batch_size, self._config.num_actions)).to(
                self._config.device
            )
        else:
            latent, action = state
        obs = self._wm.preprocess(obs)
        embed = self._wm.encoder(obs)
        latent, _ = self._wm.dynamics.obs_step(
            latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
        )
        if self._config.eval_state_mean:
            latent["stoch"] = latent["mean"]
        feat = self._wm.dynamics.get_feat(latent)
        if not training:
            actor = self._task_behavior.actor(feat)
            action = actor.mode()
        elif self._should_expl(self._step):
            actor = self._expl_behavior.actor(feat)
            action = actor.sample()
        else:
            actor = self._task_behavior.actor(feat)
            action = actor.sample()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        if self._config.actor_dist == "onehot_gumble":
            # torch.one_hot does not exist; use F.one_hot and cast back to float
            action = torch.nn.functional.one_hot(
                torch.argmax(action, dim=-1), self._config.num_actions
            ).float()
        action = self._exploration(action, training)
        policy_output = {"action": action, "logprob": logprob}
        state = (latent, action)
        return policy_output, state
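
    # Inject exploration noise: mix one-hot action probabilities toward uniform
    # for discrete actors, or add clipped Gaussian noise for continuous actions.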
    def _exploration(self, action, training):
        amount = self._config.expl_amount if training else self._config.eval_noise
        if amount == 0:
            return action
        if "onehot" in self._config.actor_dist:
            probs = amount / self._config.num_actions + (1 - amount) * action
            return tools.OneHotDist(probs=probs).sample()
        else:
            return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
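
    # Single update: train the world model on a replay batch, then the
    # actor-critic in imagination, and (if enabled) the exploration behavior;
    # collected metrics are buffered until the next logging step.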
    def _train(self, data):
        metrics = {}
        post, context, mets = self._wm._train(data)
        metrics.update(mets)
        start = post
        reward = lambda f, s, a: self._wm.heads["reward"](
            self._wm.dynamics.get_feat(s)
        ).mode()
        metrics.update(self._task_behavior._train(start, reward)[-1])
        if self._config.expl_behavior != "greedy":
            mets = self._expl_behavior.train(start, context, data)[-1]
            metrics.update({"expl_" + key: value for key, value in mets.items()})
        for name, value in metrics.items():
            if name not in self._metrics:
                self._metrics[name] = [value]
            else:
                self._metrics[name].append(value)
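

# Count the environment steps stored in a directory of episode .npz files;
# the per-episode step count is parsed from each file name.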
def count_steps(folder):
    return sum(int(str(n).split("-")[-1][:-4]) - 1 for n in folder.glob("*.npz"))
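

# Turn stored episodes into an iterator of fixed-length training batches.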
def make_dataset(episodes, config):
    generator = tools.sample_episodes(episodes, config.batch_length)
    dataset = tools.from_generator(generator, config.batch_size)
    return dataset
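

# Build a single environment for the configured suite/task and apply the
# common wrappers (time limit, action selection, episode UUIDs).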
def make_env(config, mode):
    suite, task = config.task.split("_", 1)
    if suite == "dmc":
        import envs.dmc as dmc

        env = dmc.DeepMindControl(task, config.action_repeat, config.size)
        env = wrappers.NormalizeActions(env)
    elif suite == "atari":
        import envs.atari as atari

        env = atari.Atari(
            task,
            config.action_repeat,
            config.size,
            gray=config.grayscale,
            noops=config.noops,
            lives=config.lives,
            sticky=config.stickey,
            actions=config.actions,
            resize=config.resize,
        )
        env = wrappers.OneHotAction(env)
    elif suite == "dmlab":
        import envs.dmlab as dmlab

        env = dmlab.DeepMindLabyrinth(
            task, mode if "train" in mode else "test", config.action_repeat
        )
        env = wrappers.OneHotAction(env)
    elif suite == "MemoryMaze":
        from envs.memorymaze import MemoryMaze

        env = MemoryMaze(task)
        env = wrappers.OneHotAction(env)
    elif suite == "crafter":
        import envs.crafter as crafter

        env = crafter.Crafter(task, config.size)
        env = wrappers.OneHotAction(env)
    elif suite == "minecraft":
        import envs.minecraft as minecraft

        env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
        env = wrappers.OneHotAction(env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    env = wrappers.UUID(env)
    if suite == "minecraft":
        env = wrappers.RewardObs(env)
    return env
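

# Training entry point: prepare log/episode directories, build environments,
# datasets, and the Dreamer agent, then alternate evaluation and training
# until the configured number of steps has been collected.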
def main(config):
    logdir = pathlib.Path(config.logdir).expanduser()
    config.traindir = config.traindir or logdir / "train_eps"
    config.evaldir = config.evaldir or logdir / "eval_eps"
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat

    print("Logdir", logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    # the logger counts environment steps (agent steps * action_repeat)
    logger = tools.Logger(logdir, config.action_repeat * step)

    print("Create envs.")
    if config.offline_traindir:
        directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    train_eps = tools.load_episodes(directory, limit=config.dataset_size)
    if config.offline_evaldir:
        directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    eval_eps = tools.load_episodes(directory, limit=1)
    make = lambda mode: make_env(config, mode)
    train_envs = [make("train") for _ in range(config.envs)]
    eval_envs = [make("eval") for _ in range(config.envs)]
    if config.envs > 1:
        train_envs = [Parallel(env, "process") for env in train_envs]
        eval_envs = [Parallel(env, "process") for env in eval_envs]
    else:
        train_envs = [Damy(env) for env in train_envs]
        eval_envs = [Damy(env) for env in eval_envs]
    acts = train_envs[0].action_space
    config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
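
    # Prefill the replay dataset with a uniform random policy before training
    # (skipped when training from an offline directory).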
    state = None
    if not config.offline_traindir:
        prefill = max(0, config.prefill - count_steps(config.traindir))
        print(f"Prefill dataset ({prefill} steps).")
        if hasattr(acts, "discrete"):
            random_actor = tools.OneHotDist(
                torch.zeros(config.num_actions).repeat(config.envs, 1)
            )
        else:
            random_actor = torchd.independent.Independent(
                torchd.uniform.Uniform(
                    torch.Tensor(acts.low).repeat(config.envs, 1),
                    torch.Tensor(acts.high).repeat(config.envs, 1),
                ),
                1,
            )

        def random_agent(o, d, s):
            action = random_actor.sample()
            logprob = random_actor.log_prob(action)
            return {"action": action, "logprob": logprob}, None

        state = tools.simulate(
            random_agent,
            train_envs,
            train_eps,
            config.traindir,
            logger,
            limit=config.dataset_size,
            steps=prefill,
        )
        logger.step += prefill * config.action_repeat
        print(f"Logger: ({logger.step} steps).")
    print("Simulate agent.")
    train_dataset = make_dataset(train_eps, config)
    eval_dataset = make_dataset(eval_eps, config)
    agent = Dreamer(
        train_envs[0].observation_space,
        train_envs[0].action_space,
        config,
        logger,
        train_dataset,
    ).to(config.device)
    agent.requires_grad_(requires_grad=False)
    if (logdir / "latest_model.pt").exists():
        agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
        agent._should_pretrain._once = False
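
    # Main loop: evaluate (optionally with video predictions), then train for
    # config.eval_every agent steps and checkpoint the latest model.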
    # make sure eval will be executed once after config.steps
    while agent._step < config.steps + config.eval_every:
        logger.write()
        if config.eval_episode_num > 0:
            print("Start evaluation.")
            eval_policy = functools.partial(agent, training=False)
            tools.simulate(
                eval_policy,
                eval_envs,
                eval_eps,
                config.evaldir,
                logger,
                is_eval=True,
                episodes=config.eval_episode_num,
            )
            if config.video_pred_log:
                video_pred = agent._wm.video_pred(next(eval_dataset))
                logger.video("eval_openl", to_np(video_pred))
        print("Start training.")
        state = tools.simulate(
            agent,
            train_envs,
            train_eps,
            config.traindir,
            logger,
            limit=config.dataset_size,
            steps=config.eval_every,
            state=state,
        )
        torch.save(agent.state_dict(), logdir / "latest_model.pt")

    for env in train_envs + eval_envs:
        try:
            env.close()
        except Exception:
            pass
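

# Command-line entry point: merge the configs selected via --configs from
# configs.yaml on top of the defaults, expose every resulting key as a flag,
# and run main() with the parsed configuration.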
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+")
    args, remaining = parser.parse_known_args()
    configs = yaml.safe_load(
        (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
    )

    def recursive_update(base, update):
        for key, value in update.items():
            if isinstance(value, dict) and key in base:
                recursive_update(base[key], value)
            else:
                base[key] = value

    name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
    defaults = {}
    for name in name_list:
        recursive_update(defaults, configs[name])
    parser = argparse.ArgumentParser()
    for key, value in sorted(defaults.items(), key=lambda x: x[0]):
        arg_type = tools.args_type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
    main(parser.parse_args(remaining))