diff --git a/dreamer.py b/dreamer.py
index 7d41a55..c7fa85e 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -36,7 +36,8 @@ class Dreamer(nn.Module):
         self._should_reset = tools.Every(config.reset_every)
         self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
         self._metrics = {}
-        self._step = count_steps(config.traindir)
+        # this is the update step
+        self._step = logger.step // config.action_repeat
         self._update_count = 0
         # Schedules.
         config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
@@ -226,82 +227,23 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
     env = wrappers.SelectAction(env, key="action")
-    if (mode == "train") or (mode == "eval"):
-        callbacks = [
-            functools.partial(
-                ProcessEpisodeWrap.process_episode,
-                config,
-                logger,
-                mode,
-                train_eps,
-                eval_eps,
-            )
-        ]
-        env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
+    env = wrappers.UUID(env)
+    # if (mode == "train") or (mode == "eval"):
+    #     callbacks = [
+    #         functools.partial(
+    #             ProcessEpisodeWrap.process_episode,
+    #             config,
+    #             logger,
+    #             mode,
+    #             train_eps,
+    #             eval_eps,
+    #         )
+    #     ]
+    #     env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
     env = wrappers.RewardObs(env)
     return env
 
 
-class ProcessEpisodeWrap:
-    eval_scores = []
-    eval_lengths = []
-    last_step_at_eval = -1
-    eval_done = False
-
-    @classmethod
-    def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
-        directory = dict(train=config.traindir, eval=config.evaldir)[mode]
-        cache = dict(train=train_eps, eval=eval_eps)[mode]
-        # this saved episodes is given as train_eps or eval_eps from next call
-        filename = tools.save_episodes(directory, [episode])[0]
-        length = len(episode["reward"]) - 1
-        score = float(episode["reward"].astype(np.float64).sum())
-        video = episode["image"]
-        # add new episode
-        cache[str(filename)] = episode
-        if mode == "train":
-            step_in_dataset = 0
-            for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
-                if (
-                    not config.dataset_size
-                    or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
-                ):
-                    step_in_dataset += len(ep["reward"]) - 1
-                else:
-                    del cache[key]
-            logger.scalar("dataset_size", step_in_dataset)
-        elif mode == "eval":
-            # keep only last item for saving memory
-            while len(cache) > 1:
-                # FIFO
-                cache.popitem()
-            # start counting scores for evaluation
-            if cls.last_step_at_eval != logger.step:
-                cls.eval_scores = []
-                cls.eval_lengths = []
-                cls.eval_done = False
-                cls.last_step_at_eval = logger.step
-
-            cls.eval_scores.append(score)
-            cls.eval_lengths.append(length)
-            # ignore if number of eval episodes exceeds eval_episode_num
-            if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
-                return
-            score = sum(cls.eval_scores) / len(cls.eval_scores)
-            length = sum(cls.eval_lengths) / len(cls.eval_lengths)
-            episode_num = len(cls.eval_scores)
-            logger.video(f"{mode}_policy", video[None])
-            cls.eval_done = True
-
-        print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
-        logger.scalar(f"{mode}_return", score)
-        logger.scalar(f"{mode}_length", length)
-        logger.scalar(
-            f"{mode}_episodes", len(cache) if mode == "train" else episode_num
-        )
-        logger.write(step=logger.step)
-
-
 def main(config):
     logdir = pathlib.Path(config.logdir).expanduser()
     config.traindir = config.traindir or logdir / "train_eps"
@@ -316,6 +258,7 @@ def main(config):
     config.traindir.mkdir(parents=True, exist_ok=True)
     config.evaldir.mkdir(parents=True, exist_ok=True)
     step = count_steps(config.traindir)
+    # the step in logger counts environment steps
     logger = tools.Logger(logdir, config.action_repeat * step)
 
     print("Create envs.")
@@ -357,8 +300,9 @@ def main(config):
             logprob = random_actor.log_prob(action)
             return {"action": action, "logprob": logprob}, None
 
-        state = tools.simulate(random_agent, train_envs, prefill)
-        logger.step = config.action_repeat * count_steps(config.traindir)
+        state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
+        logger.step += prefill * config.action_repeat
+        print(f"Logger: ({logger.step} steps).")
 
     print("Simulate agent.")
     train_dataset = make_dataset(train_eps, config)
@@ -379,12 +323,12 @@ def main(config):
         logger.write()
         print("Start evaluation.")
         eval_policy = functools.partial(agent, training=False)
-        tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
+        tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
         if config.video_pred_log:
             video_pred = agent._wm.video_pred(next(eval_dataset))
             logger.video("eval_openl", to_np(video_pred))
         print("Start training.")
-        state = tools.simulate(agent, train_envs, config.eval_every, state=state)
+        state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
         torch.save(agent.state_dict(), logdir / "latest_model.pt")
     for env in train_envs + eval_envs:
         try:
diff --git a/envs/wrappers.py b/envs/wrappers.py
index da80837..80f4f19 100644
--- a/envs/wrappers.py
+++ b/envs/wrappers.py
@@ -1,89 +1,9 @@
+import datetime
 import gym
 import numpy as np
 import uuid
 
 
-class CollectDataset:
-    def __init__(
-        self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
-    ):
-        self._env = env
-        self._callbacks = callbacks or ()
-        self._precision = precision
-        self._episode = None
-        self._cache = dict(train=train_eps, eval=eval_eps)[mode]
-        self._temp_name = str(uuid.uuid4())
-
-    def __getattr__(self, name):
-        return getattr(self._env, name)
-
-    def step(self, action):
-        obs, reward, done, info = self._env.step(action)
-        obs = {k: self._convert(v) for k, v in obs.items()}
-        transition = obs.copy()
-        if isinstance(action, dict):
-            transition.update(action)
-        else:
-            transition["action"] = action
-        transition["reward"] = reward
-        transition["discount"] = info.get("discount", np.array(1 - float(done)))
-        self._episode.append(transition)
-        self.add_to_cache(transition)
-        if done:
-            # detele transitions before whole episode is stored
-            del self._cache[self._temp_name]
-            self._temp_name = str(uuid.uuid4())
-            for key, value in self._episode[1].items():
-                if key not in self._episode[0]:
-                    self._episode[0][key] = 0 * value
-            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
-            episode = {k: self._convert(v) for k, v in episode.items()}
-            info["episode"] = episode
-            for callback in self._callbacks:
-                callback(episode)
-        return obs, reward, done, info
-
-    def reset(self):
-        obs = self._env.reset()
-        transition = obs.copy()
-        # Missing keys will be filled with a zeroed out version of the first
-        # transition, because we do not know what action information the agent will
-        # pass yet.
-        transition["reward"] = 0.0
-        transition["discount"] = 1.0
-        self._episode = [transition]
-        self.add_to_cache(transition)
-        return obs
-
-    def add_to_cache(self, transition):
-        if self._temp_name not in self._cache:
-            self._cache[self._temp_name] = dict()
-            for key, val in transition.items():
-                self._cache[self._temp_name][key] = [self._convert(val)]
-        else:
-            for key, val in transition.items():
-                if key not in self._cache[self._temp_name]:
-                    # fill missing data(action)
-                    self._cache[self._temp_name][key] = [self._convert(0 * val)]
-                    self._cache[self._temp_name][key].append(self._convert(val))
-                else:
-                    self._cache[self._temp_name][key].append(self._convert(val))
-
-    def _convert(self, value):
-        value = np.array(value)
-        if np.issubdtype(value.dtype, np.floating):
-            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
-        elif np.issubdtype(value.dtype, np.signedinteger):
-            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
-        elif np.issubdtype(value.dtype, np.uint8):
-            dtype = np.uint8
-        elif np.issubdtype(value.dtype, bool):
-            dtype = bool
-        else:
-            raise NotImplementedError(value.dtype)
-        return value.astype(dtype)
-
-
 class TimeLimit:
     def __init__(self, env, duration):
         self._env = env
@@ -208,3 +128,17 @@ class SelectAction:
 
     def step(self, action):
         return self._env.step(action[self._key])
+
+class UUID:
+    def __init__(self, env):
+        self._env = env
+        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
+        self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
+
+    def __getattr__(self, name):
+        return getattr(self._env, name)
+
+    def reset(self):
+        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
+        self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
+        return self._env.reset()
diff --git a/tools.py b/tools.py
index f7ee068..c210791 100644
--- a/tools.py
+++ b/tools.py
@@ -1,6 +1,7 @@
 import datetime
 import collections
 import io
+import os
 import json
 import pathlib
 import re
@@ -121,7 +122,7 @@ class Logger:
             self._writer.add_video(name, value, step, 16)
 
 
-def simulate(agent, envs, steps=0, episodes=0, state=None):
+def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
     # Initialize or unpack simulation state.
     if state is None:
         step, episode = 0, 0
@@ -137,6 +138,14 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
         if done.any():
             indices = [index for index, d in enumerate(done) if d]
             results = [envs[i].reset() for i in indices]
+            for i in indices:
+                t = results[i].copy()
+                t = {k: convert(v) for k, v in t.items()}
+                # the action will be added to the transition in add_to_cache
+                t["reward"] = 0.0
+                t["discount"] = 1.0
+                # the initial state should also be added to the cache
+                add_to_cache(cache, envs[i].id, t)
             for index, result in zip(indices, results):
                 obs[index] = result
             reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
@@ -161,26 +170,165 @@
         length += 1
         step += len(envs)
         length *= 1 - done
+        # Add to cache
+        for a, result, env in zip(action, results, envs):
+            o, r, d, info = result
+            o = {k: convert(v) for k, v in o.items()}
+            transition = o.copy()
+            if isinstance(a, dict):
+                transition.update(a)
+            else:
+                transition["action"] = a
+            transition["reward"] = r
+            transition["discount"] = info.get("discount", np.array(1 - float(d)))
+            add_to_cache(cache, env.id, transition)
+        if done.any():
+            indices = [index for index, d in enumerate(done) if d]
+            # logging for finished episodes
+            for i in indices:
+                save_episodes(directory, {envs[i].id: cache[envs[i].id]})
+                length = len(cache[envs[i].id]["reward"]) - 1
+                score = float(np.array(cache[envs[i].id]["reward"]).sum())
+                video = cache[envs[i].id]["image"]
+                if not is_eval:
+                    step_in_dataset = erase_over_episodes(cache, limit)
+                    logger.scalar(f"dataset_size", step_in_dataset)
+                    logger.scalar(f"train_return", score)
+                    logger.scalar(f"train_length", length)
+                    logger.scalar(f"train_episodes", len(cache))
+                    logger.write(step=logger.step)
+                else:
+                    if not 'eval_lengths' in locals():
+                        eval_lengths = []
+                        eval_scores = []
+                        eval_done = False
+                    # start counting scores for evaluation
+                    eval_scores.append(score)
+                    eval_lengths.append(length)
+
+                    score = sum(eval_scores) / len(eval_scores)
+                    length = sum(eval_lengths) / len(eval_lengths)
+                    logger.video(f"eval_policy", np.array(video)[None])
+
+                    if len(eval_scores) >= episodes and not eval_done:
+                        logger.scalar(f"eval_return", score)
+                        logger.scalar(f"eval_length", length)
+                        logger.scalar(f"eval_episodes", len(eval_scores))
+                        logger.write(step=logger.step)
+                        eval_done = True
+    if is_eval:
+        # keep only the last item to save memory; this cache is used for video_pred later
+        while len(cache) > 1:
+            # FIFO
+            cache.popitem(last=False)
     return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
 
 
+class CollectDataset:
+    def __init__(
+        self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
+    ):
+        self._env = env
+        self._callbacks = callbacks or ()
+        self._precision = precision
+        self._episode = None
+        self._cache = dict(train=train_eps, eval=eval_eps)[mode]
+        self._temp_name = str(uuid.uuid4())
+
+    def __getattr__(self, name):
+        return getattr(self._env, name)
+
+    def step(self, action):
+        obs, reward, done, info = self._env.step(action)
+        obs = {k: self._convert(v) for k, v in obs.items()}
+        transition = obs.copy()
+        if isinstance(action, dict):
+            transition.update(action)
+        else:
+            transition["action"] = action
+        transition["reward"] = reward
+        transition["discount"] = info.get("discount", np.array(1 - float(done)))
+        self._episode.append(transition)
+        self.add_to_cache(transition)
+        if done:
+            # delete transitions before the whole episode is stored
+            del self._cache[self._temp_name]
+            self._temp_name = str(uuid.uuid4())
+            for key, value in self._episode[1].items():
+                if key not in self._episode[0]:
+                    self._episode[0][key] = 0 * value
+            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
+            episode = {k: self._convert(v) for k, v in episode.items()}
+            info["episode"] = episode
+            for callback in self._callbacks:
+                callback(episode)
+        return obs, reward, done, info
+
+    def reset(self):
+        obs = self._env.reset()
+        transition = obs.copy()
+        # Missing keys will be filled with a zeroed out version of the first
+        # transition, because we do not know what action information the agent will
+        # pass yet.
+        transition["reward"] = 0.0
+        transition["discount"] = 1.0
+        self._episode = [transition]
+        self.add_to_cache(transition)
+        return obs
+
+def add_to_cache(cache, id, transition):
+    if id not in cache:
+        cache[id] = dict()
+        for key, val in transition.items():
+            cache[id][key] = [convert(val)]
+    else:
+        for key, val in transition.items():
+            if key not in cache[id]:
+                # fill missing data (action, etc.) the second time a key appears
+                cache[id][key] = [convert(0 * val)]
+                cache[id][key].append(convert(val))
+            else:
+                cache[id][key].append(convert(val))
+
+def erase_over_episodes(cache, dataset_size):
+    step_in_dataset = 0
+    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
+        if (
+            not dataset_size
+            or step_in_dataset + (len(ep["reward"]) - 1) <= dataset_size
+        ):
+            step_in_dataset += len(ep["reward"]) - 1
+        else:
+            del cache[key]
+    return step_in_dataset
+
+def convert(value, precision=32):
+    value = np.array(value)
+    if np.issubdtype(value.dtype, np.floating):
+        dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
+    elif np.issubdtype(value.dtype, np.signedinteger):
+        dtype = {16: np.int16, 32: np.int32, 64: np.int64}[precision]
+    elif np.issubdtype(value.dtype, np.uint8):
+        dtype = np.uint8
+    elif np.issubdtype(value.dtype, bool):
+        dtype = bool
+    else:
+        raise NotImplementedError(value.dtype)
+    return value.astype(dtype)
+
 def save_episodes(directory, episodes):
     directory = pathlib.Path(directory).expanduser()
     directory.mkdir(parents=True, exist_ok=True)
-    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
-    filenames = []
-    for episode in episodes:
-        identifier = str(uuid.uuid4().hex)
+    for filename, episode in episodes.items():
         length = len(episode["reward"])
-        filename = directory / f"{timestamp}-{identifier}-{length}.npz"
+        filename = directory / f"{filename}-{length}.npz"
         with io.BytesIO() as f1:
             np.savez_compressed(f1, **episode)
             f1.seek(0)
             with filename.open("wb") as f2:
                 f2.write(f1.read())
-        filenames.append(filename)
-    return filenames
+    return True
 
 
 def from_generator(generator, batch_size):
@@ -244,7 +392,8 @@
         except Exception as e:
             print(f"Could not load episode: {e}")
             continue
-        episodes[str(filename)] = episode
+        # use only the filename without the extension as the key
+        episodes[str(os.path.splitext(os.path.basename(filename))[0])] = episode
         total += len(episode["reward"]) - 1
         if limit and total >= limit:
            break
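
A minimal usage sketch (separate from the patch itself) of how the new module-level cache helpers added to tools.py (add_to_cache, erase_over_episodes, save_episodes) fit together. The episode ids, array shapes, step counts, and the /tmp path below are made up for illustration, and the snippet assumes it is run from the repository root so that tools.py is importable.

    import numpy as np
    import tools

    cache = {}
    for ep_id in ["20240101T000000-aaa", "20240101T000100-bbb"]:
        # reset transition: reward/discount only, no action yet
        tools.add_to_cache(
            cache, ep_id, {"obs": np.zeros(3), "reward": 0.0, "discount": 1.0}
        )
        for t in range(5):
            tools.add_to_cache(
                cache,
                ep_id,
                {
                    "obs": np.ones(3),
                    "action": np.array([float(t)]),
                    "reward": 1.0,
                    "discount": 1.0,
                },
            )

    # the missing first action was zero-filled, so every key has 6 entries
    assert len(cache[ep_id]["action"]) == len(cache[ep_id]["reward"]) == 6

    # keep at most 5 environment steps; older episodes (smaller ids) are erased first
    kept_steps = tools.erase_over_episodes(cache, 5)
    print(kept_steps, list(cache.keys()))  # 5 ['20240101T000100-bbb']

    # episodes are saved as "<id>-<length>.npz", matching what load_episodes keys on
    tools.save_episodes("/tmp/train_eps", cache)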