fix bug when using envs > 1
This commit is contained in:
parent
cd935b7dd9
commit
55ed69bdf7
@ -70,7 +70,6 @@ defaults:
|
|||||||
batch_size: 16
|
batch_size: 16
|
||||||
batch_length: 64
|
batch_length: 64
|
||||||
train_ratio: 512
|
train_ratio: 512
|
||||||
train_steps: 1
|
|
||||||
pretrain: 100
|
pretrain: 100
|
||||||
model_lr: 1e-4
|
model_lr: 1e-4
|
||||||
opt_eps: 1e-8
|
opt_eps: 1e-8
|
||||||
@ -136,6 +135,5 @@ debug:
|
|||||||
debug: True
|
debug: True
|
||||||
pretrain: 1
|
pretrain: 1
|
||||||
prefill: 1
|
prefill: 1
|
||||||
train_steps: 1
|
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
batch_length: 20
|
batch_length: 20
|
||||||
|
59
dreamer.py
59
dreamer.py
@ -39,6 +39,7 @@ class Dreamer(nn.Module):
|
|||||||
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
|
self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
|
||||||
self._metrics = {}
|
self._metrics = {}
|
||||||
self._step = count_steps(config.traindir)
|
self._step = count_steps(config.traindir)
|
||||||
|
self._update_count = 0
|
||||||
# Schedules.
|
# Schedules.
|
||||||
config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
|
config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
|
||||||
x, self._step
|
x, self._step
|
||||||
@ -75,14 +76,16 @@ class Dreamer(nn.Module):
|
|||||||
state[0][key][i] *= mask[i]
|
state[0][key][i] *= mask[i]
|
||||||
for i in range(len(state[1])):
|
for i in range(len(state[1])):
|
||||||
state[1][i] *= mask[i]
|
state[1][i] *= mask[i]
|
||||||
if training and self._should_train(step):
|
if training:
|
||||||
steps = (
|
steps = (
|
||||||
self._config.pretrain
|
self._config.pretrain
|
||||||
if self._should_pretrain()
|
if self._should_pretrain()
|
||||||
else self._config.train_steps
|
else self._should_train(step)
|
||||||
)
|
)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
self._train(next(self._dataset))
|
self._train(next(self._dataset))
|
||||||
|
self._update_count += 1
|
||||||
|
self._metrics["update_count"] = self._update_count
|
||||||
if self._should_log(step):
|
if self._should_log(step):
|
||||||
for name, values in self._metrics.items():
|
for name, values in self._metrics.items():
|
||||||
self._logger.scalar(name, float(np.mean(values)))
|
self._logger.scalar(name, float(np.mean(values)))
|
||||||
@ -227,6 +230,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
|||||||
class ProcessEpisodeWrap:
|
class ProcessEpisodeWrap:
|
||||||
eval_scores = []
|
eval_scores = []
|
||||||
eval_lengths = []
|
eval_lengths = []
|
||||||
|
last_step_at_eval = -1
|
||||||
|
eval_done = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
|
def process_episode(cls, config, logger, mode, train_eps, eval_eps, episode):
|
||||||
@ -238,20 +243,6 @@ class ProcessEpisodeWrap:
|
|||||||
score = float(episode["reward"].astype(np.float64).sum())
|
score = float(episode["reward"].astype(np.float64).sum())
|
||||||
video = episode["image"]
|
video = episode["image"]
|
||||||
cache[str(filename)] = episode
|
cache[str(filename)] = episode
|
||||||
if mode == "eval":
|
|
||||||
cls.eval_scores.append(score)
|
|
||||||
cls.eval_lengths.append(length)
|
|
||||||
# save when enought number of episodes are stored
|
|
||||||
if len(cls.eval_scores) < config.eval_episode_num:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
|
||||||
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
|
||||||
episode_num = len(cls.eval_scores)
|
|
||||||
cls.eval_scores = []
|
|
||||||
cls.eval_lengths = []
|
|
||||||
cache.clear()
|
|
||||||
|
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
total = 0
|
total = 0
|
||||||
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
for key, ep in reversed(sorted(cache.items(), key=lambda x: x[0])):
|
||||||
@ -260,16 +251,39 @@ class ProcessEpisodeWrap:
|
|||||||
else:
|
else:
|
||||||
del cache[key]
|
del cache[key]
|
||||||
logger.scalar("dataset_size", total)
|
logger.scalar("dataset_size", total)
|
||||||
|
# use dataset_size as log step for a condition of envs > 1
|
||||||
|
log_step = total * config.action_repeat
|
||||||
|
elif mode == "eval":
|
||||||
|
# start saving episodes for evaluation
|
||||||
|
if cls.last_step_at_eval != logger.step:
|
||||||
|
# keep only last item
|
||||||
|
while len(cache) > 1:
|
||||||
|
# FIFO
|
||||||
|
cache.popitem()
|
||||||
|
cls.eval_scores = []
|
||||||
|
cls.eval_lengths = []
|
||||||
|
cls.eval_done = False
|
||||||
|
cls.last_step_at_eval = logger.step
|
||||||
|
|
||||||
|
cls.eval_scores.append(score)
|
||||||
|
cls.eval_lengths.append(length)
|
||||||
|
# ignore if number of eval episodes exceeds eval_episode_num
|
||||||
|
if len(cls.eval_scores) < config.eval_episode_num or cls.eval_done:
|
||||||
|
return
|
||||||
|
score = sum(cls.eval_scores) / len(cls.eval_scores)
|
||||||
|
length = sum(cls.eval_lengths) / len(cls.eval_lengths)
|
||||||
|
episode_num = len(cls.eval_scores)
|
||||||
|
log_step = logger.step
|
||||||
|
logger.video(f"{mode}_policy", video[None])
|
||||||
|
cls.eval_done = True
|
||||||
|
|
||||||
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
|
print(f"{mode.title()} episode has {length} steps and return {score:.1f}.")
|
||||||
logger.scalar(f"{mode}_return", score)
|
logger.scalar(f"{mode}_return", score)
|
||||||
logger.scalar(f"{mode}_length", length)
|
logger.scalar(f"{mode}_length", length)
|
||||||
logger.scalar(
|
logger.scalar(
|
||||||
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
f"{mode}_episodes", len(cache) if mode == "train" else episode_num
|
||||||
)
|
)
|
||||||
if mode == "eval" or config.expl_gifs:
|
logger.write(step=log_step)
|
||||||
# only last video in eval videos is preservad
|
|
||||||
logger.video(f"{mode}_policy", video[None])
|
|
||||||
logger.write()
|
|
||||||
|
|
||||||
|
|
||||||
def main(config):
|
def main(config):
|
||||||
@ -329,7 +343,6 @@ def main(config):
|
|||||||
return {"action": action, "logprob": logprob}, None
|
return {"action": action, "logprob": logprob}, None
|
||||||
|
|
||||||
tools.simulate(random_agent, train_envs, prefill)
|
tools.simulate(random_agent, train_envs, prefill)
|
||||||
tools.simulate(random_agent, eval_envs, episodes=1)
|
|
||||||
logger.step = config.action_repeat * count_steps(config.traindir)
|
logger.step = config.action_repeat * count_steps(config.traindir)
|
||||||
|
|
||||||
print("Simulate agent.")
|
print("Simulate agent.")
|
||||||
@ -345,10 +358,10 @@ def main(config):
|
|||||||
while agent._step < config.steps:
|
while agent._step < config.steps:
|
||||||
logger.write()
|
logger.write()
|
||||||
print("Start evaluation.")
|
print("Start evaluation.")
|
||||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
|
||||||
logger.video("eval_openl", to_np(video_pred))
|
|
||||||
eval_policy = functools.partial(agent, training=False)
|
eval_policy = functools.partial(agent, training=False)
|
||||||
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
|
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
|
||||||
|
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||||
|
logger.video("eval_openl", to_np(video_pred))
|
||||||
print("Start training.")
|
print("Start training.")
|
||||||
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
|
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
|
||||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||||
|
31
tools.py
31
tools.py
@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import collections
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
@ -74,24 +75,26 @@ class Logger:
|
|||||||
def video(self, name, value):
|
def video(self, name, value):
|
||||||
self._videos[name] = np.array(value)
|
self._videos[name] = np.array(value)
|
||||||
|
|
||||||
def write(self, fps=False):
|
def write(self, fps=False, step=False):
|
||||||
|
if not step:
|
||||||
|
step = self.step
|
||||||
scalars = list(self._scalars.items())
|
scalars = list(self._scalars.items())
|
||||||
if fps:
|
if fps:
|
||||||
scalars.append(('fps', self._compute_fps(self.step)))
|
scalars.append(('fps', self._compute_fps(step)))
|
||||||
print(f'[{self.step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
|
print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in scalars))
|
||||||
with (self._logdir / 'metrics.jsonl').open('a') as f:
|
with (self._logdir / 'metrics.jsonl').open('a') as f:
|
||||||
f.write(json.dumps({'step': self.step, ** dict(scalars)}) + '\n')
|
f.write(json.dumps({'step': step, ** dict(scalars)}) + '\n')
|
||||||
for name, value in scalars:
|
for name, value in scalars:
|
||||||
self._writer.add_scalar('scalars/' + name, value, self.step)
|
self._writer.add_scalar('scalars/' + name, value, step)
|
||||||
for name, value in self._images.items():
|
for name, value in self._images.items():
|
||||||
self._writer.add_image(name, value, self.step)
|
self._writer.add_image(name, value, step)
|
||||||
for name, value in self._videos.items():
|
for name, value in self._videos.items():
|
||||||
name = name if isinstance(name, str) else name.decode('utf-8')
|
name = name if isinstance(name, str) else name.decode('utf-8')
|
||||||
if np.issubdtype(value.dtype, np.floating):
|
if np.issubdtype(value.dtype, np.floating):
|
||||||
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
value = np.clip(255 * value, 0, 255).astype(np.uint8)
|
||||||
B, T, H, W, C = value.shape
|
B, T, H, W, C = value.shape
|
||||||
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
|
value = value.transpose(1, 4, 2, 0, 3).reshape((1, T, C, H, B*W))
|
||||||
self._writer.add_video(name, value, self.step, 16)
|
self._writer.add_video(name, value, step, 16)
|
||||||
|
|
||||||
self._writer.flush()
|
self._writer.flush()
|
||||||
self._scalars = {}
|
self._scalars = {}
|
||||||
@ -215,7 +218,7 @@ def sample_episodes(episodes, length=None, balance=False, seed=0):
|
|||||||
|
|
||||||
def load_episodes(directory, limit=None, reverse=True):
|
def load_episodes(directory, limit=None, reverse=True):
|
||||||
directory = pathlib.Path(directory).expanduser()
|
directory = pathlib.Path(directory).expanduser()
|
||||||
episodes = {}
|
episodes = collections.OrderedDict()
|
||||||
total = 0
|
total = 0
|
||||||
if reverse:
|
if reverse:
|
||||||
for filename in reversed(sorted(directory.glob('*.npz'))):
|
for filename in reversed(sorted(directory.glob('*.npz'))):
|
||||||
@ -677,15 +680,13 @@ class Every:
|
|||||||
|
|
||||||
def __call__(self, step):
|
def __call__(self, step):
|
||||||
if not self._every:
|
if not self._every:
|
||||||
return False
|
return 0
|
||||||
if self._last is None:
|
if self._last is None:
|
||||||
self._last = step
|
self._last = step
|
||||||
return True
|
return 1
|
||||||
if step >= self._last + self._every:
|
count = int((step - self._last) / self._every)
|
||||||
self._last += self._every
|
self._last += self._every * count
|
||||||
return True
|
return count
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Once:
|
class Once:
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user