diff --git a/dreamer.py b/dreamer.py
index 8c5a91a..05ec8ed 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -150,7 +150,6 @@ class Dreamer(nn.Module):
             return tools.OneHotDist(probs=probs).sample()
         else:
             return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
-        raise NotImplementedError(self._config.action_noise)
 
     def _train(self, data):
         metrics = {}
@@ -181,7 +180,7 @@ def make_dataset(episodes, config):
     return dataset
 
 
-def make_env(config, logger, mode, train_eps, eval_eps):
+def make_env(config, mode):
     suite, task = config.task.split("_", 1)
     if suite == "dmc":
         import envs.dmc as dmc
@@ -260,7 +259,7 @@ def main(config):
     else:
         directory = config.evaldir
     eval_eps = tools.load_episodes(directory, limit=1)
-    make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
+    make = lambda mode: make_env(config, mode)
     train_envs = [make("train") for _ in range(config.envs)]
     eval_envs = [make("eval") for _ in range(config.envs)]
     acts = train_envs[0].action_space
diff --git a/tools.py b/tools.py
index 4fcfdb6..4be88ef 100644
--- a/tools.py
+++ b/tools.py
@@ -123,7 +123,7 @@ class Logger:
 
 
 def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
-    # Initialize or unpack simulation state.
+    # initialize or unpack simulation state
     if state is None:
         step, episode = 0, 0
         done = np.ones(len(envs), bool)
@@ -134,7 +134,7 @@
     else:
         step, episode, done, length, obs, agent_state, reward = state
     while (steps and step < steps) or (episodes and episode < episodes):
-        # Reset envs if necessary.
+        # reset envs if necessary
         if done.any():
             indices = [index for index, d in enumerate(done) if d]
             results = [envs[i].reset() for i in indices]
@@ -148,7 +148,7 @@
                 add_to_cache(cache, envs[i].id, t)
             for index, result in zip(indices, results):
                 obs[index] = result
-        # Step agents.
+        # step agents
         obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
         action, agent_state = agent(obs, done, agent_state)
         if isinstance(action, dict):
@@ -159,7 +159,7 @@
         else:
             action = np.array(action)
         assert len(action) == len(envs)
-        # Step envs.
+        # step envs
         results = [e.step(a) for e, a in zip(envs, action)]
         obs, reward, done = zip(*[p[:3] for p in results])
         obs = list(obs)
@@ -169,7 +169,7 @@
         length += 1
         step += len(envs)
         length *= 1 - done
-        # Add to cache
+        # add to cache
         for a, result, env in zip(action, results, envs):
             o, r, d, info = result
             o = {k: convert(v) for k, v in o.items()}
@@ -267,7 +267,7 @@ class CollectDataset:
     def reset(self):
         obs = self._env.reset()
         transition = obs.copy()
-        # Missing keys will be filled with a zeroed out version of the first
+        # missing keys will be filled with a zeroed out version of the first
         # transition, because we do not know what action information the agent will
         # pass yet.
         transition["reward"] = 0.0
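
Context for the first dreamer.py hunk: the removed raise was unreachable, since both branches of the preceding if/else return before control can reach it. A minimal standalone sketch of that control flow, in which the function name sample_noisy_action and the use_onehot flag are illustrative assumptions rather than names from the patch:

    import torch
    import torch.distributions as torchd

    def sample_noisy_action(action, amount, probs, use_onehot):
        # every path through this if/else returns, so a statement placed after
        # the block, like the removed raise, can never execute
        if use_onehot:
            return torchd.one_hot_categorical.OneHotCategorical(probs=probs).sample()
        else:
            return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)

The make_env change is a matching signature cleanup: the logger, train_eps, and eval_eps parameters are dropped, and the only call site shown in the patch (the lambda in main) is updated to match.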