erased unused lines of code

parent d1f4d5c709
commit 106317015d
@@ -150,7 +150,6 @@ class Dreamer(nn.Module):
             return tools.OneHotDist(probs=probs).sample()
         else:
             return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
-        raise NotImplementedError(self._config.action_noise)

     def _train(self, data):
         metrics = {}
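Both branches of the if/else above already return, so the deleted raise could never execute; this is the dead code the commit title refers to. A minimal sketch of the surviving exploration-noise logic, assuming amount is the configured noise scale, probs the smoothed one-hot probabilities, and tools.OneHotDist the repository's one-hot distribution wrapper; the free-function name and parameters here are illustrative, not the actual method signature:

import torch
import torch.distributions as torchd

import tools  # repository module providing OneHotDist


def add_exploration_noise(action, amount, actor_dist, num_actions):
    # hypothetical free-function version of the method shown in the hunk
    if amount == 0:
        return action
    if "onehot" in actor_dist:
        # mix uniform noise into the one-hot probabilities and resample
        probs = amount / num_actions + (1 - amount) * action
        return tools.OneHotDist(probs=probs).sample()
    else:
        # continuous actions: add Gaussian noise and clip to the valid range
        return torch.clip(torchd.normal.Normal(action, amount).sample(), -1, 1)
    # any statement placed here is unreachable, which is why the raise was removed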
@@ -181,7 +180,7 @@ def make_dataset(episodes, config):
     return dataset


-def make_env(config, logger, mode, train_eps, eval_eps):
+def make_env(config, mode):
     suite, task = config.task.split("_", 1)
     if suite == "dmc":
         import envs.dmc as dmc
@@ -260,7 +259,7 @@ def main(config):
     else:
         directory = config.evaldir
     eval_eps = tools.load_episodes(directory, limit=1)
-    make = lambda mode: make_env(config, logger, mode, train_eps, eval_eps)
+    make = lambda mode: make_env(config, mode)
     train_envs = [make("train") for _ in range(config.envs)]
     eval_envs = [make("eval") for _ in range(config.envs)]
     acts = train_envs[0].action_space
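Taken together, the two hunks above are one refactor seen from both sides: make_env never used logger, train_eps, or eval_eps, so they are dropped from the signature and from the factory lambda in main. A rough sketch of the resulting call pattern, with a dummy suite and a SimpleNamespace config standing in for the real config object and env constructors:

from types import SimpleNamespace


def make_env(config, mode):
    # only the config and the train/eval mode are needed to build an environment
    suite, task = config.task.split("_", 1)
    if suite == "dummy":
        return SimpleNamespace(suite=suite, task=task, mode=mode)  # stand-in env
    raise NotImplementedError(suite)


config = SimpleNamespace(task="dummy_walk", envs=2)
make = lambda mode: make_env(config, mode)  # no logger or episode caches threaded through
train_envs = [make("train") for _ in range(config.envs)]
eval_envs = [make("eval") for _ in range(config.envs)]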
tools.py (12 lines changed)
@@ -123,7 +123,7 @@ class Logger:


 def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, steps=0, episodes=0, state=None):
-    # Initialize or unpack simulation state.
+    # initialize or unpack simulation state
     if state is None:
         step, episode = 0, 0
         done = np.ones(len(envs), bool)
@@ -134,7 +134,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
     else:
         step, episode, done, length, obs, agent_state, reward = state
     while (steps and step < steps) or (episodes and episode < episodes):
-        # Reset envs if necessary.
+        # reset envs if necessary
         if done.any():
             indices = [index for index, d in enumerate(done) if d]
             results = [envs[i].reset() for i in indices]
@@ -148,7 +148,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
                 add_to_cache(cache, envs[i].id, t)
         for index, result in zip(indices, results):
             obs[index] = result
-        # Step agents.
+        # step agents
         obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
         action, agent_state = agent(obs, done, agent_state)
         if isinstance(action, dict):
@@ -159,7 +159,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
         else:
             action = np.array(action)
         assert len(action) == len(envs)
-        # Step envs.
+        # step envs
         results = [e.step(a) for e, a in zip(envs, action)]
         obs, reward, done = zip(*[p[:3] for p in results])
         obs = list(obs)
@@ -169,7 +169,7 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
         length += 1
         step += len(envs)
         length *= 1 - done
-        # Add to cache
+        # add to cache
         for a, result, env in zip(action, results, envs):
             o, r, d, info = result
             o = {k: convert(v) for k, v in o.items()}
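The comment tweaks in the last five hunks all sit inside the same driver loop in tools.py. As orientation, a stripped-down sketch of that loop's structure, with caching, logging, and episode bookkeeping omitted and the agent/env interfaces treated as assumptions:

import numpy as np


def simulate_sketch(agent, envs, steps=0, episodes=0):
    # initialize simulation state
    step, episode = 0, 0
    done = np.ones(len(envs), bool)
    obs = [None] * len(envs)
    agent_state = None
    while (steps and step < steps) or (episodes and episode < episodes):
        # reset envs if necessary
        if done.any():
            for i in [idx for idx, d in enumerate(done) if d]:
                obs[i] = envs[i].reset()
        # step agents: batch the per-env observations and query the policy
        batch = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
        action, agent_state = agent(batch, done, agent_state)
        # step envs and unpack (obs, reward, done) from each result
        results = [e.step(a) for e, a in zip(envs, action)]
        obs, reward, done = (list(x) for x in zip(*[r[:3] for r in results]))
        done = np.array(done)
        step += len(envs)
        episode += int(done.sum())
    return step, episode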
@@ -267,7 +267,7 @@ class CollectDataset:
     def reset(self):
         obs = self._env.reset()
         transition = obs.copy()
-        # Missing keys will be filled with a zeroed out version of the first
+        # missing keys will be filled with a zeroed out version of the first
         # transition, because we do not know what action information the agent will
         # pass yet.
         transition["reward"] = 0.0
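The reset path above seeds the first transition with reward = 0.0 and leaves action-related keys absent; per the comment, they are zero-filled once a real step reveals their shapes. A hypothetical illustration of that back-filling idea (key names and shapes are made up for the example, not taken from the diff):

import numpy as np


def zero_fill_first_transition(episode):
    # back-fill keys missing from the reset transition with zeroed copies
    # of the values seen in the first real step
    first, second = episode[0], episode[1]
    for key, value in second.items():
        if key not in first:
            first[key] = np.zeros_like(np.asarray(value))
    return episode


# usage: episode[0] came from reset(), episode[1] from the first step()
episode = [
    {"image": np.zeros((64, 64, 3), np.uint8), "reward": 0.0},
    {"image": np.zeros((64, 64, 3), np.uint8), "reward": 1.0,
     "action": np.array([0.5, -0.2])},
]
episode = zero_fill_first_transition(episode)
assert (episode[0]["action"] == 0).all()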