separated cache management of episode from env
parent 88514ec022
commit 9ca5082da3

98 dreamer.py
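Summary: this change removes per-episode caching from the environment side (the CollectDataset wrapper and the ProcessEpisodeWrap callback) and moves it into tools.py. tools.simulate now receives the episode cache, the save directory, and the logger directly, keys each episode by the id from the new UUID wrapper, and uses the new helpers add_to_cache, erase_over_episodes, and save_episodes to maintain and persist the dataset. The call sites in dreamer.py pass train_eps/config.traindir for training and eval_eps/config.evaldir with is_eval=True for evaluation, which keeps only the most recent eval episode in memory.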
@@ -36,7 +36,8 @@ class Dreamer(nn.Module):
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
        self._metrics = {}
        self._step = count_steps(config.traindir)
        # this is update step
        self._step = logger.step // config.action_repeat
        self._update_count = 0
        # Schedules.
        config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
@@ -226,82 +227,23 @@ def make_env(config, logger, mode, train_eps, eval_eps):
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    if (mode == "train") or (mode == "eval"):
        callbacks = [
            functools.partial(
                ProcessEpisodeWrap.process_episode,
                config,
                logger,
                mode,
                train_eps,
                eval_eps,
            )
        ]
        env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
    env = wrappers.UUID(env)
    # if (mode == "train") or (mode == "eval"):
    #     callbacks = [
    #         functools.partial(
    #             ProcessEpisodeWrap.process_episode,
    #             config,
    #             logger,
    #             mode,
    #             train_eps,
    #             eval_eps,
    #         )
    #     ]
    #     env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
    env = wrappers.RewardObs(env)
    return env


class ProcessEpisodeWrap:
    eval_scores = []
    eval_lengths = []
    last_step_at_eval = -1
    eval_done = False

    @classmethod
    def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
        directory = dict(train=config.traindir, eval=config.evaldir)[mode]
        cache = dict(train=train_eps, eval=eval_eps)[mode]
        # the saved episodes are passed back in as train_eps or eval_eps on the next call
        filename = tools.save_episodes(directory, [episode])[0]
        length = len(episode["reward"]) - 1
        score = float(episode["reward"].astype(np.float64).sum())
        video = episode["image"]
        # add new episode
        cache[str(filename)] = episode
        if mode == "train":
            step_in_dataset = 0
            for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
                if (
                    not config.dataset_size
                    or step_in_dataset + (len(ep["reward"]) - 1) <= config.dataset_size
                ):
                    step_in_dataset += len(ep["reward"]) - 1
                else:
                    del cache[key]
            logger.scalar("dataset_size", step_in_dataset)
        elif mode == "eval":
            # keep only the last item to save memory
            while len(cache) > 1:
                # FIFO
                cache.popitem()
            # start counting scores for evaluation
            if cls.last_step_at_eval != logger.step:
                cls.eval_scores = []
                cls.eval_lengths = []
                cls.eval_done = False
                cls.last_step_at_eval = logger.step

            cls.eval_scores.append(score)
            cls.eval_lengths.append(length)
            # ignore if the number of eval episodes exceeds eval_episode_num
            if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
                return
            score = sum(cls.eval_scores) / len(cls.eval_scores)
            length = sum(cls.eval_lengths) / len(cls.eval_lengths)
            episode_num = len(cls.eval_scores)
            logger.video(f"{mode}_policy", video[None])
            cls.eval_done = True

        print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
        logger.scalar(f"{mode}_return", score)
        logger.scalar(f"{mode}_length", length)
        logger.scalar(
            f"{mode}_episodes", len(cache) if mode == "train" else episode_num
        )
        logger.write(step=logger.step)


def main(config):
    logdir = pathlib.Path(config.logdir).expanduser()
    config.traindir = config.traindir or logdir / "train_eps"
@@ -316,6 +258,7 @@ def main(config):
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    step = count_steps(config.traindir)
    # step in logger is environmental step
    logger = tools.Logger(logdir, config.action_repeat * step)

    print("Create envs.")
@@ -357,8 +300,9 @@ def main(config):
            logprob = random_actor.log_prob(action)
            return {"action": action, "logprob": logprob}, None

        state = tools.simulate(random_agent, train_envs, prefill)
        logger.step = config.action_repeat * count_steps(config.traindir)
        state = tools.simulate(random_agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=prefill)
        logger.step += prefill * config.action_repeat
        print(f"Logger: ({logger.step} steps).")

    print("Simulate agent.")
    train_dataset = make_dataset(train_eps, config)
@@ -379,12 +323,12 @@ def main(config):
        logger.write()
        print("Start evaluation.")
        eval_policy = functools.partial(agent, training=False)
        tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
        tools.simulate(eval_policy, eval_envs, eval_eps, config.evaldir, logger, is_eval=True, episodes=config.eval_episode_num)
        if config.video_pred_log:
            video_pred = agent._wm.video_pred(next(eval_dataset))
            logger.video("eval_openl", to_np(video_pred))
        print("Start training.")
        state = tools.simulate(agent, train_envs, config.eval_every, state=state)
        state = tools.simulate(agent, train_envs, train_eps, config.traindir, logger, limit=config.dataset_size, steps=config.eval_every, state=state)
        torch.save(agent.state_dict(), logdir / "latest_model.pt")
    for env in train_envs + eval_envs:
        try:
wrappers.py
@@ -1,89 +1,9 @@
import datetime
import gym
import numpy as np
import uuid


class CollectDataset:
    def __init__(
        self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
    ):
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        self._episode = None
        self._cache = dict(train=train_eps, eval=eval_eps)[mode]
        self._temp_name = str(uuid.uuid4())

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        obs = {k: self._convert(v) for k, v in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition["action"] = action
        transition["reward"] = reward
        transition["discount"] = info.get("discount", np.array(1 - float(done)))
        self._episode.append(transition)
        self.add_to_cache(transition)
        if done:
            # delete per-step transitions before the whole episode is stored
            del self._cache[self._temp_name]
            self._temp_name = str(uuid.uuid4())
            for key, value in self._episode[1].items():
                if key not in self._episode[0]:
                    self._episode[0][key] = 0 * value
            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
            episode = {k: self._convert(v) for k, v in episode.items()}
            info["episode"] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        obs = self._env.reset()
        transition = obs.copy()
        # Missing keys will be filled with a zeroed out version of the first
        # transition, because we do not know what action information the agent will
        # pass yet.
        transition["reward"] = 0.0
        transition["discount"] = 1.0
        self._episode = [transition]
        self.add_to_cache(transition)
        return obs

    def add_to_cache(self, transition):
        if self._temp_name not in self._cache:
            self._cache[self._temp_name] = dict()
            for key, val in transition.items():
                self._cache[self._temp_name][key] = [self._convert(val)]
        else:
            for key, val in transition.items():
                if key not in self._cache[self._temp_name]:
                    # fill missing data (action)
                    self._cache[self._temp_name][key] = [self._convert(0 * val)]
                    self._cache[self._temp_name][key].append(self._convert(val))
                else:
                    self._cache[self._temp_name][key].append(self._convert(val))

    def _convert(self, value):
        value = np.array(value)
        if np.issubdtype(value.dtype, np.floating):
            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
        elif np.issubdtype(value.dtype, np.signedinteger):
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        elif np.issubdtype(value.dtype, np.uint8):
            dtype = np.uint8
        elif np.issubdtype(value.dtype, bool):
            dtype = bool
        else:
            raise NotImplementedError(value.dtype)
        return value.astype(dtype)


class TimeLimit:
    def __init__(self, env, duration):
        self._env = env
@@ -208,3 +128,17 @@ class SelectAction:

    def step(self, action):
        return self._env.step(action[self._key])

class UUID:
    def __init__(self, env):
        self._env = env
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"

    def __getattr__(self, name):
        return getattr(self._env, name)

    def reset(self):
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        self.id = f"{timestamp}-{str(uuid.uuid4().hex)}"
        return self._env.reset()
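The new UUID wrapper is what links an environment to its cache entry: every reset() assigns a fresh env.id, and tools.simulate uses that id as the key for add_to_cache. Below is a minimal sketch of that interaction, assuming the repo's tools.py and wrappers.py are importable; DummyEnv is a placeholder invented for illustration, not part of the repo.

```python
import numpy as np
import tools      # this repo's tools.py
import wrappers   # this repo's wrappers.py

class DummyEnv:
    """Placeholder env for illustration only (not part of the repo)."""
    def reset(self):
        return {"image": np.zeros((64, 64, 3), np.uint8)}

env = wrappers.UUID(DummyEnv())
cache = {}

obs = env.reset()                 # UUID.reset() assigns a fresh env.id
first = dict(obs)
first["reward"] = 0.0             # initial transition, mirroring tools.simulate
first["discount"] = 1.0
tools.add_to_cache(cache, env.id, first)   # episode keyed by the per-reset id

print(list(cache.keys()))  # e.g. ['20240101T000000-<32 hex chars>']
```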
167 tools.py
@@ -1,6 +1,7 @@
import datetime
import collections
import io
import os
import json
import pathlib
import re
@@ -121,7 +122,7 @@ class Logger:
        self._writer.add_video(name, value, step, 16)


def simulate(agent, envs, steps=0, episodes=0, state=None):
def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
    # Initialize or unpack simulation state.
    if state is None:
        step, episode = 0, 0
@@ -137,6 +138,14 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            results = [envs[i].reset() for i in indices]
            for i in indices:
                t = results[i].copy()
                t = {k: convert(v) for k, v in t.items()}
                # action will be added to transition in add_to_cache
                t["reward"] = 0.0
                t["discount"] = 1.0
                # initial state should be added to cache
                add_to_cache(cache, envs[i].id, t)
            for index, result in zip(indices, results):
                obs[index] = result
            reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
@@ -161,26 +170,165 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
        length += 1
        step += len(envs)
        length *= 1 - done
        # Add to cache
        for a, result, env in zip(action, results, envs):
            o, r, d, info = result
            o = {k: convert(v) for k, v in o.items()}
            transition = o.copy()
            if isinstance(a, dict):
                transition.update(a)
            else:
                transition["action"] = a
            transition["reward"] = r
            transition["discount"] = info.get("discount", np.array(1 - float(d)))
            add_to_cache(cache, env.id, transition)

        if done.any():
            indices = [index for index, d in enumerate(done) if d]
            # logging for finished episodes
            for i in indices:
                save_episodes(directory, {envs[i].id: cache[envs[i].id]})
                length = len(cache[envs[i].id]["reward"]) - 1
                score = float(np.array(cache[envs[i].id]["reward"]).sum())
                video = cache[envs[i].id]["image"]
                if not is_eval:
                    step_in_dataset = erase_over_episodes(cache, limit)
                    logger.scalar(f"dataset_size", step_in_dataset)
                    logger.scalar(f"train_return", score)
                    logger.scalar(f"train_length", length)
                    logger.scalar(f"train_episodes", len(cache))
                    logger.write(step=logger.step)
                else:
                    if not 'eval_lengths' in locals():
                        eval_lengths = []
                        eval_scores = []
                        eval_done = False
                    # start counting scores for evaluation
                    eval_scores.append(score)
                    eval_lengths.append(length)

                    score = sum(eval_scores) / len(eval_scores)
                    length = sum(eval_lengths) / len(eval_lengths)
                    logger.video(f"eval_policy", np.array(video)[None])

                    if len(eval_scores) >= episodes and not eval_done:
                        logger.scalar(f"eval_return", score)
                        logger.scalar(f"eval_length", length)
                        logger.scalar(f"eval_episodes", len(eval_scores))
                        logger.write(step=logger.step)
                        eval_done = True
    if is_eval:
        # keep only the last item to save memory; this cache is used for video_pred later
        while len(cache) > 1:
            # FIFO
            cache.popitem(last=False)
    return (step - steps, episode - episodes, done, length, obs, agent_state, reward)


class CollectDataset:
    def __init__(
        self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
    ):
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        self._episode = None
        self._cache = dict(train=train_eps, eval=eval_eps)[mode]
        self._temp_name = str(uuid.uuid4())

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        obs = {k: self._convert(v) for k, v in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition["action"] = action
        transition["reward"] = reward
        transition["discount"] = info.get("discount", np.array(1 - float(done)))
        self._episode.append(transition)
        self.add_to_cache(transition)
        if done:
            # delete per-step transitions before the whole episode is stored
            del self._cache[self._temp_name]
            self._temp_name = str(uuid.uuid4())
            for key, value in self._episode[1].items():
                if key not in self._episode[0]:
                    self._episode[0][key] = 0 * value
            episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
            episode = {k: self._convert(v) for k, v in episode.items()}
            info["episode"] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        obs = self._env.reset()
        transition = obs.copy()
        # Missing keys will be filled with a zeroed out version of the first
        # transition, because we do not know what action information the agent will
        # pass yet.
        transition["reward"] = 0.0
        transition["discount"] = 1.0
        self._episode = [transition]
        self.add_to_cache(transition)
        return obs

def add_to_cache(cache, id, transition):
    if id not in cache:
        cache[id] = dict()
        for key, val in transition.items():
            cache[id][key] = [convert(val)]
    else:
        for key, val in transition.items():
            if key not in cache[id]:
                # fill missing data (action, etc.) on the second step
                cache[id][key] = [convert(0 * val)]
                cache[id][key].append(convert(val))
            else:
                cache[id][key].append(convert(val))

def erase_over_episodes(cache, dataset_size):
    step_in_dataset = 0
    for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
        if (
            not dataset_size
            or step_in_dataset + (len(ep["reward"]) - 1) <= dataset_size
        ):
            step_in_dataset += len(ep["reward"]) - 1
        else:
            del cache[key]
    return step_in_dataset

def convert(value, precision=32):
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating):
        dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
    elif np.issubdtype(value.dtype, np.signedinteger):
        dtype = {16: np.int16, 32: np.int32, 64: np.int64}[precision]
    elif np.issubdtype(value.dtype, np.uint8):
        dtype = np.uint8
    elif np.issubdtype(value.dtype, bool):
        dtype = bool
    else:
        raise NotImplementedError(value.dtype)
    return value.astype(dtype)

def save_episodes(directory, episodes):
    directory = pathlib.Path(directory).expanduser()
    directory.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
    filenames = []
    for episode in episodes:
        identifier = str(uuid.uuid4().hex)
    for filename, episode in episodes.items():
        length = len(episode["reward"])
        filename = directory / f"{timestamp}-{identifier}-{length}.npz"
        filename = directory / f"{filename}-{length}.npz"
        with io.BytesIO() as f1:
            np.savez_compressed(f1, **episode)
            f1.seek(0)
            with filename.open("wb") as f2:
                f2.write(f1.read())
        filenames.append(filename)
    return filenames
    return True


def from_generator(generator, batch_size):
@@ -244,7 +392,8 @@ def load_episodes(directory, limit=None, reverse=True):
        except Exception as e:
            print(f"Could not load episode: {e}")
            continue
        episodes[str(filename)] = episode
        # extract only filename without extension
        episodes[str(os.path.splitext(os.path.basename(filename))[0])] = episode
        total += len(episode["reward"]) - 1
        if limit and total >= limit:
            break
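For reference, here is a minimal sketch of how the new cache helpers behave together, using toy episode ids, toy rewards, and a made-up 100-step limit (none of these values come from the commit):

```python
import numpy as np
import tools  # this repo's tools.py

cache = {}

# Stream two toy 60-step episodes into the cache, keyed like UUID wrapper ids.
for ep_id in ["20240101T000000-aaa", "20240101T000100-bbb"]:
    tools.add_to_cache(cache, ep_id, {"reward": 0.0, "discount": 1.0})  # reset transition
    for _ in range(60):
        tools.add_to_cache(
            cache, ep_id, {"reward": 1.0, "discount": 1.0, "action": np.zeros(3)}
        )

# Trim the cache to at most 100 environment steps; newer episodes are kept
# first because the timestamp-prefixed keys are iterated in reverse-sorted order.
steps = tools.erase_over_episodes(cache, 100)
print(steps, len(cache))  # -> 60 1 (the older episode was erased)
```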