diff --git a/configs.yaml b/configs.yaml
index 00446ae..02fb89c 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -70,7 +70,6 @@ defaults:
   batch_size: 16
   batch_length: 64
   train_ratio: 512
-  train_steps: 1
  pretrain: 100
   model_lr: 1e-4
   opt_eps: 1e-8
@@ -136,6 +135,5 @@ debug:
   debug: True
   pretrain: 1
   prefill: 1
-  train_steps: 1
   batch_size: 10
   batch_length: 20
diff --git a/dreamer.py b/dreamer.py
index 25bf33d..877325b 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -39,6 +39,7 @@ class Dreamer(nn.Module):
         self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
         self._metrics = {}
         self._step = count_steps(config.traindir)
+        self._update_count = 0
         # Schedules.
         config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
             x, self._step
@@ -75,14 +76,16 @@ class Dreamer(nn.Module):
                     state[0][key][i] *= mask[i]
             for i in range(len(state[1])):
                 state[1][i] *= mask[i]
-        if training and self._should_train(step):
+        if training:
             steps = (
                 self._config.pretrain
                 if self._should_pretrain()
-                else self._config.train_steps
+                else self._should_train(step)
             )
             for _ in range(steps):
                 self._train(next(self._dataset))
+                self._update_count += 1
+                self._metrics["update_count"] = self._update_count
             if self._should_log(step):
                 for name, values in self._metrics.items():
                     self._logger.scalar(name, float(np.mean(values)))
@@ -227,6 +230,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
 class ProcessEpisodeWrap:
     eval_scores = []
     eval_lengths = []
+    last_step_at_eval = -1
+    eval_done = False
 
     @classmethod
     def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
@@ -238,20 +243,6 @@ class ProcessEpisodeWrap:
         score = float(episode["reward"].astype(np.float64).sum())
         video = episode["image"]
         cache[str(filename)] = episode
-        if mode == "eval":
-            cls.eval_scores.append(score)
-            cls.eval_lengths.append(length)
-            # save when enought number of episodes are stored
-            if len(cls.eval_scores) < config.eval_episode_num:
-                return
-            else:
-                score = sum(cls.eval_scores) / len(cls.eval_scores)
-                length = sum(cls.eval_lengths) / len(cls.eval_lengths)
-                episode_num = len(cls.eval_scores)
-                cls.eval_scores = []
-                cls.eval_lengths = []
-                cache.clear()
-
         if mode == "train":
             total = 0
             for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
@@ -260,16 +251,39 @@ class ProcessEpisodeWrap:
             else:
                 del cache[key]
             logger.scalar("dataset_size", total)
+            # use dataset_size as log step for a condition of envs > 1
+            log_step = total * config.action_repeat
+        elif mode == "eval":
+            # start saving episodes for evaluation
+            if cls.last_step_at_eval != logger.step:
+                # keep only last item
+                while len(cache) > 1:
+                    # FIFO
+                    cache.popitem()
+                cls.eval_scores = []
+                cls.eval_lengths = []
+                cls.eval_done = False
+                cls.last_step_at_eval = logger.step
+
+            cls.eval_scores.append(score)
+            cls.eval_lengths.append(length)
+            # ignore if number of eval episodes exceeds eval_episode_num
+            if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
+                return
+            score = sum(cls.eval_scores) / len(cls.eval_scores)
+            length = sum(cls.eval_lengths) / len(cls.eval_lengths)
+            episode_num = len(cls.eval_scores)
+            log_step = logger.step
+            logger.video(f"{mode}_policy", video[None])
+            cls.eval_done = True
+
         print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
         logger.scalar(f"{mode}_return", score)
         logger.scalar(f"{mode}_length", length)
         logger.scalar(
             f"{mode}_episodes", len(cache) if mode == "train" else episode_num
         )
-        if mode == "eval" or config.expl_gifs:
-            # only last video in eval videos is preservad
-            logger.video(f"{mode}_policy", video[None])
-        logger.write()
+        logger.write(step=log_step)
 
 
 def main(config):
@@ -329,7 +343,6 @@ def main(config):
         return {"action": action, "logprob": logprob}, None
 
     tools.simulate(random_agent, train_envs, prefill)
-    tools.simulate(random_agent, eval_envs, episodes=1)
     logger.step = config.action_repeat * count_steps(config.traindir)
 
     print("Simulate agent.")
@@ -345,10 +358,10 @@ def main(config):
     while agent._step < config.steps:
         logger.write()
         print("Start evaluation.")
-        video_pred = agent._wm.video_pred(next(eval_dataset))
-        logger.video("eval_openl", to_np(video_pred))
         eval_policy = functools.partial(agent, training=False)
         tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
+        video_pred = agent._wm.video_pred(next(eval_dataset))
+        logger.video("eval_openl", to_np(video_pred))
         print("Start training.")
         state = tools.simulate(agent, train_envs, config.eval_every, state=state)
         torch.save(agent.state_dict(), logdir / "latest_model.pt")
diff --git a/tools.py b/tools.py
index a73b238..db7e911 100644
--- a/tools.py
+++ b/tools.py
@@ -1,4 +1,5 @@
 import datetime
+import collections
 import io
 import json
 import pathlib
@@ -74,24 +75,26 @@ class Logger:
     def video(self, name, value):
         self._videos[name] = np.array(value)
 
-    def write(self, fps=False):
+    def write(self, fps=False, step=False):
+        if not step:
+            step = self.step
         scalars = list(self._scalars.items())
         if fps:
-            scalars.append(('fps', self._compute_fps(self.step)))
-        print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
+            scalars.append(('fps', self._compute_fps(step)))
+        print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
         with (self._logdir / 'metrics.jsonl').open('a') as f:
-            f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n')
+            f.write(json.dumps({'step': step, ** dict(scalars)}) + '\n')
         for name, value in scalars:
-            self._writer.add_scalar('scalars/' + name, value, self.step)
+            self._writer.add_scalar('scalars/' + name, value, step)
         for name, value in self._images.items():
-            self._writer.add_image(name, value, self.step)
+            self._writer.add_image(name, value, step)
         for name, value in self._videos.items():
             name = name if isinstance(name, str) else name.decode('utf-8')
             if np.issubdtype(value.dtype, np.floating):
                 value = np.clip(255 * value, 0, 255).astype(np.uint8)
             B, T, H, W, C = value.shape
             value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
-            self._writer.add_video(name, value, self.step, 16)
+            self._writer.add_video(name, value, step, 16)
         self._writer.flush()
         self._scalars = {}
         self._images = {}
@@ -215,7 +218,7 @@ def sample_episodes(episodes, length=None, balance=False, seed=0):
 
 def load_episodes(directory, limit=None, reverse=True):
     directory = pathlib.Path(directory).expanduser()
-    episodes = {}
+    episodes = collections.OrderedDict()
     total = 0
     if reverse:
         for filename in reversed(sorted(directory.glob('*.npz'))):
@@ -677,15 +680,13 @@ class Every:
 
     def __call__(self, step):
         if not self._every:
-            return False
+            return 0
         if self._last is None:
             self._last = step
-            return True
-        if step >= self._last + self._every:
-            self._last += self._every
-            return True
-        return False
-
+            return 1
+        count = int((step - self._last) / self._every)
+        self._last += self._every * count
+        return count
 
 
 class Once:
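
Notes on the change (reviewer sketches, not part of the patch):

The rewritten tools.Every returns how many update intervals have elapsed
instead of a boolean, which is what lets Dreamer.__call__ replace the fixed
train_steps with a train_ratio-driven number of gradient updates per call.
Below is a minimal standalone sketch, assuming the usual __init__ that the
diff context implies (it stores the interval and sets self._last = None);
the driver numbers are illustrative, taken from the defaults batch_size 16,
batch_length 64, train_ratio 512:

    class Every:
        def __init__(self, every):
            self._every = every
            self._last = None

        def __call__(self, step):
            if not self._every:
                return 0
            if self._last is None:
                self._last = step
                return 1
            # Whole intervals elapsed since the last trigger: 0 when the
            # interval has not passed yet, greater than 1 after a long gap.
            count = int((step - self._last) / self._every)
            self._last += self._every * count
            return count

    # 16 * 64 = 1024 replayed steps per batch at train_ratio = 512 would
    # mean one update every 1024 / 512 = 2 environment steps.
    should_train = Every(1024 / 512)
    for step in range(10):
        n = should_train(step)
        if n:
            print(f"step {step}: run {n} gradient update(s)")

Because the count is cumulative, a call after a long gap returns more than
one pending update, and the `for _ in range(steps)` loop in
Dreamer.__call__ catches up instead of silently dropping updates the way
the old boolean did.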
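The new step argument on Logger.write exists because train and eval
episodes now log against different x-axes: train metrics use the replayed
dataset size scaled by action_repeat, while eval metrics keep the policy
step. A minimal sketch of the fallback behavior; MiniLogger is a
hypothetical stand-in for tools.Logger:

    class MiniLogger:
        def __init__(self):
            self.step = 100  # the logger's own step counter
            self._scalars = {}

        def scalar(self, name, value):
            self._scalars[name] = value

        def write(self, step=False):
            # A falsy step (False, 0, None) falls back to self.step,
            # mirroring the `if not step:` guard in the patch.
            if not step:
                step = self.step
            print(f"[{step}]", dict(self._scalars))
            self._scalars = {}

    logger = MiniLogger()
    logger.scalar("train_return", 12.3)
    logger.write(step=2048)  # logged at x = 2048
    logger.scalar("eval_return", 45.6)
    logger.write()           # falls back to x = 100

One side effect of the truthiness check: an explicit step=0 is
indistinguishable from the default False and gets replaced by self.step, so
callers cannot log at step 0 through this override.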
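load_episodes now returns a collections.OrderedDict so the eval-episode
cache has a deterministic iteration and popitem() order. Worth noting:
OrderedDict.popitem() pops LIFO by default, so the `while len(cache) > 1`
trim keeps the first-inserted entry; with load_episodes(reverse=True)
inserting the newest file first, that is the most recent episode. A small
sketch with hypothetical episode keys:

    import collections

    cache = collections.OrderedDict()
    # reverse-sorted loading inserts the newest episode first
    cache["20230101T120300-ep3.npz"] = "newest"
    cache["20230101T120200-ep2.npz"] = "older"
    cache["20230101T120100-ep1.npz"] = "oldest"

    # popitem() removes the most recently inserted key (LIFO), so only
    # the first-inserted, i.e. newest, episode survives the trim.
    while len(cache) > 1:
        cache.popitem()

    print(list(cache))  # ['20230101T120300-ep3.npz']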